Exemplo n.º 1
0
def train(args, model, datasets, all_dataset_sampler, task_id=-1):
    """Multi-task training loop.

    Builds AdamW with three parameter groups:

    * weight-decayed weights;
    * non-decayed bias / LayerNorm weights;
    * ``alpha_list`` parameters at a fixed lr of 0.1.

    It then builds a linear warmup/decay schedule. When configured, it wraps
    the model for apex fp16, DataParallel and DistributedDataParallel. It
    trains for ``args.num_train_epochs`` epochs, saving a checkpoint every
    ``args.save_steps`` optimizer steps.

    Args:
        args: run configuration namespace. ``args.train_batch_size`` (and,
            when ``args.warmup_ratio > 0``, ``args.warmup_steps``) are
            written back onto it.
        model: model called with ``input_ids``, ``attention_mask``,
            ``token_type_ids``, ``labels``, ``heads`` and ``task_id``.
        datasets: training dataset passed to the DataLoader.
        all_dataset_sampler: sampler over ``datasets``. When None, a
            ``RandomSampler`` with batch size ``args.train_batch_size`` is
            used.
        task_id: never read; the task id is taken from each batch.

    Returns:
        The (possibly wrapped) trained model.
    """
    args.train_batch_size = args.mini_batch_size * max(1, args.n_gpu)
    no_decay = ["bias", "LayerNorm.weight"]
    alpha_sets = ["alpha_list"]

    # Optimizer steps over the whole run: examples * epochs / effective batch.
    t_total = len(datasets) * args.num_train_epochs // (
        args.gradient_accumulation_steps * args.train_batch_size)

    def _matches(name, patterns):
        return any(pattern in name for pattern in patterns)

    named_params = list(model.named_parameters())
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in named_params
            if not _matches(n, no_decay + alpha_sets)
        ],
        'weight_decay': args.weight_decay
    }, {
        # Alpha parameters are excluded here so that no tensor can end up in
        # two groups (torch.optim rejects duplicated parameters).
        'params': [
            p for n, p in named_params
            if _matches(n, no_decay) and not _matches(n, alpha_sets)
        ],
        'weight_decay': 0.0
    }, {
        'params': [p for n, p in named_params if _matches(n, alpha_sets)],
        'lr': 1e-1
    }]

    if args.warmup_ratio > 0:
        args.warmup_steps = int(t_total * args.warmup_ratio)

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    # TODO Need change sampler !
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    logger.info("***** Running training *****")
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total batch size = %d ",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info(" Total training steps = %d ", t_total)
    logger.info(" Warmup steps = %d ", args.warmup_steps)

    is_main_process = args.local_rank in [-1, 0]
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    set_seed(args)
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not is_main_process)

    if all_dataset_sampler is None:
        train_sampler = RandomSampler(datasets)
        train_dataloader = DataLoader(datasets,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
    else:
        # NOTE(review): DataLoader's default batch_size=1 applies here;
        # presumably the sampler yields whole batches of indices and the
        # squeeze() calls below drop the extra leading dim — confirm.
        train_dataloader = DataLoader(datasets, sampler=all_dataset_sampler)
    # TODO Need be changed if in dist training

    for _ in train_iterator:
        model.train()
        iter_bar = tqdm(train_dataloader,
                        desc="Iter(loss=X.XXX, lr=X.XXXXXXXX)",
                        disable=not is_main_process)
        for step, batch in enumerate(iter_bar):
            input_ids = batch[0].squeeze().long().to(args.device)
            input_mask = batch[1].squeeze().long().to(args.device)
            segment_ids = batch[2].squeeze().long().to(args.device)
            label_ids = batch[3].squeeze().long().to(args.device)
            head_ids = batch[4].squeeze().long().to(args.device)

            # The model is called with a single task id per batch, so every
            # example in the batch must carry the same one.
            task_ids = batch[5]
            if task_ids.max() != task_ids.min():
                raise ValueError(
                    "All examples in a batch must share one task_id.")
            task_id = task_ids.max().unsqueeze(0)
            inputs = {
                "input_ids": input_ids,
                "attention_mask": input_mask,
                "token_type_ids": segment_ids,
                "labels": label_ids,
                "heads": head_ids,
                "task_id": task_id
            }

            outputs = model(**inputs)
            # In task-embedding and alpha modes the loss is outputs[1].
            if args.do_task_embedding or args.do_alpha:
                loss = outputs[1]
            else:
                loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # average the per-replica losses
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            loss_value = loss.item()  # one device sync per step
            tr_loss += loss_value
            logging_loss += loss_value

            if is_main_process:
                iter_bar.set_description(
                    'Iter (loss=%5.3f) lr=%9.7f' %
                    (loss_value, scheduler.get_last_lr()[0]))

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # must follow optimizer.step()
                model.zero_grad()
                global_step += 1

                if global_step % 100 == 0:
                    # Sum of the (accumulation-scaled) losses since last log.
                    logger.info("%d - %d loss = %f", global_step - 99,
                                global_step, logging_loss)
                    logging_loss = 0.0

                if (is_main_process and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Unwrap DataParallel / DistributedDataParallel.
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as exactly max_steps optimizer steps have run.
            if args.max_steps > 0 and global_step >= args.max_steps:
                iter_bar.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    return model
Exemplo n.º 2
0
def train(args, model, datasets, all_dataset_sampler, task_id=-1):
    """Multi-task training with an adversarial divergence regularizer.

    Each step runs the model twice:

    * once on word embeddings perturbed by random noise of magnitude
      ``args.adv_init_mag``;
    * once on the clean inputs.

    The KL divergence between the two output distributions becomes a
    penalty ``args.gamma / KL`` added to the clean loss. Minimizing the
    total therefore pushes the divergence up.

    Args:
        args: run configuration; ``args.train_batch_size`` is written back.
        model: model exposing ``bert.embeddings.word_embeddings`` and
            ``classifier_list``, possibly wrapped in DataParallel/DDP.
        datasets: training data (see the NOTE on ``t_total``).
        all_dataset_sampler: sampler over ``datasets``. When None, a
            ``RandomSampler`` with ``args.train_batch_size`` is used.
        task_id: never read; the task id is taken from each batch.

    Returns:
        The (possibly wrapped) trained model.
    """
    args.train_batch_size = args.mini_batch_size * max(1, args.n_gpu)
    no_decay = ["bias", "LayerNorm.weight"]
    alpha_sets = ["alpha_list"]

    # Optimizer steps over the whole run: examples * epochs / effective batch.
    # NOTE(review): sums len(x) over the elements of ``datasets`` while the
    # DataLoader below indexes ``datasets`` directly — confirm it iterates
    # over per-task datasets.
    t_total = int(
        sum(len(x) for x in datasets) * args.num_train_epochs //
        (args.gradient_accumulation_steps * args.train_batch_size))

    def _matches(name, patterns):
        return any(pattern in name for pattern in patterns)

    named_params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in named_params if not _matches(n, no_decay + alpha_sets)],
         'weight_decay': args.weight_decay},
        # Alpha parameters excluded so no tensor appears in two groups.
        {'params': [p for n, p in named_params
                    if _matches(n, no_decay) and not _matches(n, alpha_sets)],
         'weight_decay': 0.0},
        {'params': [p for n, p in named_params if _matches(n, alpha_sets)], 'lr': 1e-1},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # DataParallel and DDP both hide the network behind ``.module``; attribute
    # access (bert, classifier_list) must go through the unwrapped model.
    base_model = model.module if hasattr(model, "module") else model

    logger.info("***** Running training *****")
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    is_main_process = args.local_rank in [-1, 0]
    # Element-wise KL; reduced manually (sum over classes, mean over rest).
    r_loss_func = nn.KLDivLoss(reduction="none")
    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    set_seed(args)
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=not is_main_process)

    for _ in train_iterator:
        if all_dataset_sampler is None:
            train_sampler = RandomSampler(datasets)
            train_dataloader = DataLoader(datasets, sampler=train_sampler, batch_size=args.train_batch_size)
        else:
            train_dataloader = DataLoader(datasets, sampler=all_dataset_sampler)
        model.train()
        iter_bar = tqdm(train_dataloader, desc="Iter(loss=X.XXX)", disable=not is_main_process)
        for step, batch in enumerate(iter_bar):
            input_ids = batch[0].squeeze().long().to(args.device)
            input_mask = batch[1].squeeze().long().to(args.device)
            segment_ids = batch[2].squeeze().long().to(args.device)
            label_ids = batch[3].squeeze().long().to(args.device)
            head_ids = batch[4].squeeze().long().to(args.device)

            # One task id is passed per batch, so all examples must agree.
            task_ids = batch[5]
            if task_ids.max() != task_ids.min():
                raise ValueError("All examples in a batch must share one task_id.")
            task_id = task_ids.max().unsqueeze(0)
            inputs = {"input_ids": input_ids,
                      "attention_mask": input_mask,
                      "token_type_ids": segment_ids,
                      "labels": label_ids,
                      "heads": head_ids,
                      "task_id": task_id}

            # Parsing heads return (loss, arc logits, label logits); other
            # heads return (loss, logits).
            is_parsing = isinstance(base_model.classifier_list[task_id], DeepBiAffineDecoderV2)

            # ============= adversarial training ===============
            # step 1: random perturbation of the input word embeddings
            embeds_init = base_model.bert.embeddings.word_embeddings(input_ids)
            if args.adv_init_mag > 0:
                float_mask = input_mask.to(embeds_init)
                input_lengths = torch.sum(float_mask, 1)
                # Uniform noise in [-1, 1], zeroed on padding positions.
                delta = torch.zeros_like(embeds_init).uniform_(-1, 1) * float_mask.unsqueeze(2)
                # NOTE(review): scaled by 1/sqrt(sequence length) only, not
                # 1/sqrt(length * hidden size) — confirm this is intended.
                mag = args.adv_init_mag / torch.sqrt(input_lengths)
                delta = (delta * mag.view(-1, 1, 1)).detach()
            else:
                delta = torch.zeros_like(embeds_init)

            # step 2: forward the perturbed sample (embeddings replace ids);
            # copied so the clean ``inputs`` dict is left untouched.
            adv_inputs = copy.deepcopy(inputs)
            adv_inputs["inputs_embeds"] = delta + embeds_init
            adv_inputs["input_ids"] = None
            adv_outputs = model(**adv_inputs)

            # step 3: forward the clean sample and measure the divergence
            outputs = model(**inputs)
            loss = outputs[0]
            if is_parsing:
                r_loss_arc = r_loss_func(F.log_softmax(adv_outputs[1].float(), dim=-1),
                                         F.softmax(outputs[1].float(), dim=-1)).sum(dim=-1).mean()
                r_loss_label = r_loss_func(F.log_softmax(adv_outputs[2].float(), dim=-1),
                                           F.softmax(outputs[2].float(), dim=-1)).sum(dim=-1).mean()
                r_loss = r_loss_arc + r_loss_label
            else:
                r_loss = r_loss_func(F.log_softmax(adv_outputs[1].float(), dim=-1),
                                     F.softmax(outputs[1].float(), dim=-1)).sum(dim=-1).mean()

            # step 4: maximize the divergence by minimizing its reciprocal; the
            # clamp keeps a (near-)zero divergence from yielding an inf loss.
            r_loss = torch.reciprocal(r_loss.clamp(min=1e-8))
            loss = loss + args.gamma * r_loss

            # NOTE(review): in task-embedding / alpha modes the loss is taken
            # from outputs[1], which discards the adversarial term above, and
            # steps 2-3 assumed outputs[0] is the loss — confirm the intended
            # output layout for these modes.
            if args.do_task_embedding or args.do_alpha:
                loss = outputs[1]

            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            loss_value = loss.item()
            tr_loss += loss_value

            if is_main_process:
                iter_bar.set_description("Iter (loss=%5.3f)" % loss_value)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # must follow optimizer.step()
                model.zero_grad()
                global_step += 1

                if is_main_process and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    base_model.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as exactly max_steps optimizer steps have run.
            if args.max_steps > 0 and global_step >= args.max_steps:
                iter_bar.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    return model
Exemplo n.º 3
0
def train(args, model, datasets, all_dataset_sampler, task_id=-1):
    """Multi-task training with FreeLB-style adversarial embedding perturbation.

    For every batch:

    * A perturbation ``delta`` on the word embeddings is refined for
      ``args.adv_steps`` gradient-ascent steps.
    * Each step is normalised by the ``args.norm_type`` norm ("l2" or
      "linf") and, when ``args.adv_max_norm > 0``, projected back into that
      bound.
    * Every step's loss, scaled by ``1 / args.adv_steps``, is
      back-propagated into the model, so the parameter gradients
      accumulate.
    * A single optimizer step follows.

    Args:
        args: run configuration. ``args.train_batch_size`` (and
            ``args.warmup_steps`` when ``args.warmup_ratio > 0``) are
            written back.
        model: model exposing ``bert.embeddings.word_embeddings`` and
            accepting ``inputs_embeds``; may be wrapped in DataParallel/DDP.
        datasets: training data (see the NOTE on ``t_total``).
        all_dataset_sampler: sampler over ``datasets``. When None, a
            ``RandomSampler`` with ``args.train_batch_size`` is used.
        task_id: never read; the task id is taken from each batch.

    Returns:
        The (possibly wrapped) trained model.

    Raises:
        ValueError: for an unknown ``args.norm_type`` or
            ``args.adv_steps < 1``.
    """
    # Validate up front instead of failing mid-training.
    if args.norm_type not in ("l2", "linf"):
        raise ValueError("Norm type {} not specified.".format(args.norm_type))
    if args.adv_steps < 1:
        raise ValueError("adv_steps must be >= 1, got {}".format(args.adv_steps))

    args.train_batch_size = args.mini_batch_size * max(1, args.n_gpu)
    no_decay = ["bias", "LayerNorm.weight"]
    alpha_sets = ["alpha_list"]

    # Optimizer steps over the whole run: examples * epochs / effective batch.
    # Without the epoch factor the linear schedule reaches lr=0 after the
    # first epoch. NOTE(review): sums len(x) over the elements of
    # ``datasets`` — confirm it iterates over per-task datasets.
    t_total = int(
        sum(len(x) for x in datasets) * args.num_train_epochs //
        (args.gradient_accumulation_steps * args.train_batch_size))

    def _matches(name, patterns):
        return any(pattern in name for pattern in patterns)

    named_params = list(model.named_parameters())
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in named_params
            if not _matches(n, no_decay + alpha_sets)
        ],
        'weight_decay': args.weight_decay
    }, {
        # Alpha parameters excluded so no tensor appears in two groups.
        'params': [
            p for n, p in named_params
            if _matches(n, no_decay) and not _matches(n, alpha_sets)
        ],
        'weight_decay': 0.0
    }, {
        'params': [p for n, p in named_params if _matches(n, alpha_sets)],
        'lr': 1e-1
    }]

    if args.warmup_ratio > 0:
        args.warmup_steps = int(t_total * args.warmup_ratio)

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    def _word_embeddings(ids):
        # DataParallel and DDP expose the wrapped network as ``.module``.
        base = model.module if hasattr(model, "module") else model
        return base.bert.embeddings.word_embeddings(ids)

    logger.info("***** Running training *****")
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    is_main_process = args.local_rank in [-1, 0]
    global_step = 0
    model.zero_grad()
    set_seed(args)
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not is_main_process)

    for _ in train_iterator:
        if all_dataset_sampler is None:
            train_sampler = RandomSampler(datasets)
            train_dataloader = DataLoader(datasets,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size)
        else:
            train_dataloader = DataLoader(datasets,
                                          sampler=all_dataset_sampler)
        model.train()
        iter_bar = tqdm(train_dataloader,
                        desc="Iter(loss=X.XXX, lr=X.XXXXXXXX)",
                        disable=not is_main_process)
        for step, batch in enumerate(iter_bar):
            input_ids = batch[0].squeeze().long().to(args.device)
            input_mask = batch[1].squeeze().long().to(args.device)
            segment_ids = batch[2].squeeze().long().to(args.device)
            label_ids = batch[3].squeeze().long().to(args.device)
            head_ids = batch[4].squeeze().long().to(args.device)

            # One task id is passed per batch, so all examples must agree.
            task_ids = batch[5]
            if task_ids.max() != task_ids.min():
                raise ValueError(
                    "All examples in a batch must share one task_id.")
            task_id = task_ids.max().unsqueeze(0)
            inputs = {
                "input_ids": input_ids,
                "attention_mask": input_mask,
                "token_type_ids": segment_ids,
                "labels": label_ids,
                "heads": head_ids,
                "task_id": task_id
            }

            # ================ adversarial training ================
            embeds_init = _word_embeddings(input_ids)

            # Initial perturbation, zeroed on padding positions.
            if args.adv_init_mag > 0:
                float_mask = input_mask.to(embeds_init)
                input_lengths = torch.sum(float_mask, 1)
                if args.norm_type == "l2":
                    delta = torch.zeros_like(embeds_init).uniform_(
                        -1, 1) * float_mask.unsqueeze(2)
                    dims = input_lengths * embeds_init.size(-1)
                    mag = args.adv_init_mag / torch.sqrt(dims)
                    delta = (delta * mag.view(-1, 1, 1)).detach()
                else:  # "linf" (validated above)
                    delta = torch.zeros_like(embeds_init).uniform_(
                        -args.adv_init_mag,
                        args.adv_init_mag) * float_mask.unsqueeze(2)
            else:
                delta = torch.zeros_like(embeds_init)

            for _astep in range(args.adv_steps):
                # (0) forward on the perturbed embeddings
                delta.requires_grad_()
                inputs["inputs_embeds"] = delta + embeds_init
                inputs["input_ids"] = None
                outputs = model(**inputs)

                # Pick the loss BEFORE back-propagating: in task-embedding
                # and alpha modes the loss is outputs[1].
                if args.do_task_embedding or args.do_alpha:
                    loss = outputs[1]
                else:
                    loss = outputs[0]

                # (1) backward; parameter grads accumulate across ascent steps
                if args.n_gpu > 1:
                    loss = loss.mean()
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss = loss / args.adv_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # (2) gradient with respect to the perturbation
                delta_grad = delta.grad.clone().detach()

                # (3) normalised ascent step on delta, then projection
                if args.norm_type == "l2":
                    denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1),
                                        dim=1).view(-1, 1, 1)
                    denorm = torch.clamp(denorm, min=1e-8)
                    delta = (delta +
                             args.adv_lr * delta_grad / denorm).detach()
                    if args.adv_max_norm > 0:
                        delta_norm = torch.norm(
                            delta.view(delta.size(0), -1).float(),
                            p=2,
                            dim=1).detach()
                        exceed_mask = (delta_norm >
                                       args.adv_max_norm).to(embeds_init)
                        # Rows above the bound are rescaled onto it; the clamp
                        # avoids inf * 0 = nan for an all-zero row.
                        reweights = (args.adv_max_norm /
                                     delta_norm.clamp(min=1e-8) * exceed_mask +
                                     (1 - exceed_mask)).view(-1, 1, 1)
                        delta = (delta * reweights).detach()
                else:  # "linf"
                    denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1),
                                        dim=1,
                                        p=float("inf")).view(-1, 1, 1)
                    denorm = torch.clamp(denorm, min=1e-8)
                    delta = (delta +
                             args.adv_lr * delta_grad / denorm).detach()
                    if args.adv_max_norm > 0:
                        delta = torch.clamp(delta, -args.adv_max_norm,
                                            args.adv_max_norm).detach()

                # backward() released the graph through the embeddings, so
                # recompute them for the next ascent step.
                embeds_init = _word_embeddings(input_ids)
            # ====================== end adversarial ======================

            if is_main_process:
                iter_bar.set_description("Iter (loss=%5.3f, lr=%9.7f)" %
                                         (loss.item(), scheduler.get_last_lr()[0]))

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # must follow optimizer.step()
                model.zero_grad()
                global_step += 1

                if (is_main_process and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Unwrap DataParallel / DistributedDataParallel.
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as exactly max_steps optimizer steps have run.
            if args.max_steps > 0 and global_step >= args.max_steps:
                iter_bar.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    return model
Exemplo n.º 4
0
def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
    """ Train the model

    Runs AdamW with a linear warmup/decay schedule. It optionally uses apex
    fp16, DataParallel and DistributedDataParallel. It logs to TensorBoard
    every ``args.logging_steps`` optimizer steps, evaluating on "dev" when
    ``args.evaluate_during_training`` is set on a single process. It saves
    a checkpoint every ``args.save_steps`` steps.

    Args:
        args: run configuration. ``args.train_batch_size`` is written back,
            and so is ``args.num_train_epochs`` when ``args.max_steps > 0``.
        train_dataset: dataset whose batches are indexed as
            0=input_ids, 1=attention_mask, 2=token_type_ids, 3=labels.
        model: model to train.
        tokenizer, labels, pad_token_label_id: forwarded to ``evaluate``.

    Returns:
        ``(global_step, mean training loss per optimizer step)``. The mean
        is 0.0 when no optimizer step was taken.
    """
    is_main_process = args.local_rank in [-1, 0]
    if is_main_process:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = (RandomSampler(train_dataset) if args.local_rank == -1
                     else DistributedSampler(train_dataset))
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        # max(1, ...) avoids dividing by zero when an epoch has fewer
        # batches than the gradient-accumulation window.
        updates_per_epoch = max(
            1, len(train_dataloader) // args.gradient_accumulation_steps)
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{
        "params": [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": args.weight_decay
    }, {
        "params": [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not is_main_process)
    set_seed(
        args)  # Added here for reproductibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):
            # Re-enter train mode every step: evaluate() may run in between.
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                # XLM and RoBERTa don't use segment_ids
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert", "xlnet"] else None)

            outputs = model(**inputs)
            # model outputs are always tuple in pytorch-transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule (after optimizer.step())
                model.zero_grad()
                global_step += 1

                if (is_main_process and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results, _ = evaluate(args,
                                              model,
                                              tokenizer,
                                              labels,
                                              pad_token_label_id,
                                              mode="dev")
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_last_lr()[0],
                                         global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if (is_main_process and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = model.module if hasattr(
                        model, "module"
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as exactly max_steps optimizer steps have run.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    if is_main_process:
        tb_writer.close()

    # Guard the mean against a run that never reached an optimizer step.
    mean_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, mean_loss