def train(args, train_dataset, model, tokenizer):
    """Train the model.

    Builds the data loader, an AdamW optimizer with a linear warmup/decay
    schedule, optionally wraps the model for apex fp16, ``DataParallel`` or
    ``DistributedDataParallel`` execution, and then runs the training loop.

    Args:
        args: Run configuration namespace. This function writes
            ``train_batch_size`` and ``warmup_steps`` on it and, when
            ``max_steps > 0``, also ``num_train_epochs``.
        train_dataset: Dataset handed to the ``DataLoader``; each collated
            batch is indexed as ``(input_ids, attention_mask,
            token_type_ids, labels)``.
        model: Model called as ``model(**inputs)``; its first output is used
            as the loss.
        tokenizer: Forwarded unchanged to ``evaluate``.

    Returns:
        ``(global_step, mean_loss)`` where ``global_step`` is the number of
        optimizer updates performed and ``mean_loss`` is the accumulated
        loss divided by ``global_step`` (``0.0`` if no update was performed).

    Raises:
        ImportError: If ``args.fp16`` is set but apex is not installed.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        num_training_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
    args.warmup_steps = int(num_training_steps * args.warmup_proportion)

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(params=optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", num_training_steps)

    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    seed_everything(args.seed)  # Added here for reproducibility
    is_main_process = args.local_rank in [-1, 0]
    reached_max_steps = False
    for epoch in range(int(args.num_train_epochs)):
        if args.local_rank != -1:
            # Without set_epoch the DistributedSampler yields the same
            # ordering in every epoch.
            train_sampler.set_epoch(epoch)
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            # Re-enter train mode every step; the periodic evaluate() call
            # below may have switched the model to eval mode (not visible
            # from this function).
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3],
                'token_type_ids': batch[2]
            }
            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once on the fully accumulated gradient; clipping after
                # every micro-batch would rescale earlier contributions
                # repeatedly.
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Logging and checkpointing are tied to real optimizer
                # updates so they neither fire at global_step == 0 nor repeat
                # for every micro-batch of one accumulation window.
                if (args.local_rank == -1 and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Only evaluate when single GPU otherwise metrics may
                    # not average well
                    evaluate(args, model, tokenizer)

                if (is_main_process and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)
            pbar(step, {'loss': loss.item()})

            # Honour --max_steps: the schedule above is sized for exactly
            # this many updates.
            if args.max_steps > 0 and global_step >= args.max_steps:
                reached_max_steps = True
                break
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
        if reached_max_steps:
            break
    # Guard against a run in which no optimizer update happened.
    mean_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, mean_loss
def main():
    """Pre-train an ELECTRA generator/discriminator pair on pregenerated data.

    Parses command-line options, locates the pregenerated data files under
    ``<data_dir>/corpus/train``, builds the model, optimizer and linear
    warmup schedule, and runs the training loop with periodic logging and
    checkpointing into ``output_dir``.

    Raises:
        SystemExit: If not even the first pregenerated data file is found.
        ValueError: If ``--gradient_accumulation_steps`` is smaller than 1.
        ImportError: If ``--fp16`` is requested but apex is not installed.
    """
    parser = ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default="dataset",
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--config_path",
                        default="prev_trained_model/electra_small/config.json",
                        type=str)
    parser.add_argument("--vocab_path",
                        default="prev_trained_model/electra_small/vocab.txt",
                        type=str)
    parser.add_argument(
        "--output_dir",
        default="outputs",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--model_path",
                        default='prev_trained_model/electra_small',
                        type=str)
    parser.add_argument('--data_name', default='electra', type=str)
    parser.add_argument(
        "--file_num",
        type=int,
        default=10,
        help="Number of dynamic masking to pregenerate (with different masks)")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=4,
                        help="Number of epochs to train for")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    # type=int is required: without it a value given on the command line is
    # a string and the modulo checks in the training loop fail.
    parser.add_argument('--num_eval_steps', default=100, type=int)
    parser.add_argument('--num_save_steps', default=2000, type=int)
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=128,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--gen_weight",
                        default=1.0,
                        type=float,
                        help='masked language modeling / generator loss')
    parser.add_argument("--disc_weight",
                        default=50,
                        type=float,
                        help='discriminator loss')
    parser.add_argument('--untied_generator',
                        action='store_true',
                        help='tie all generator/discriminator weights?')
    parser.add_argument('--temperature',
                        default=0,
                        type=float,
                        help='temperature for sampling from generator')
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', default=1.0, type=float)
    parser.add_argument("--learning_rate",
                        default=0.000176,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O2',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--continue_train',
                        default='',
                        help="continue train path")
    args = parser.parse_args()

    args.data_dir = Path(args.data_dir)
    args.output_dir = Path(args.output_dir)
    # Make sure the output directory exists before the log file path inside
    # it is handed to init_logger and before checkpoints are written.
    args.output_dir.mkdir(parents=True, exist_ok=True)

    pregenerated_data = args.data_dir / "corpus/train"
    init_logger(log_file=str(args.output_dir / "train_albert_model.log"))
    assert pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by prepare_lm_data_mask.py!"

    # Count the examples in the leading run of data files that are present.
    samples_per_epoch = 0
    num_data_files = 0
    for i in range(args.file_num):
        data_file = pregenerated_data / f"{args.data_name}_file_{i}.json"
        metrics_file = pregenerated_data / f"{args.data_name}_file_{i}_metrics.json"
        if data_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch += metrics['num_training_examples']
            num_data_files += 1
        else:
            if i == 0:
                raise SystemExit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            break
    logger.info(f"samples_per_epoch: {samples_per_epoch}")
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        f"device: {device} , distributed training: {bool(args.local_rank != -1)}, 16-bits training: {args.fp16}"
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
        )
    # From here on train_batch_size is the per-forward-pass (micro) batch.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    seed_everything(args.seed)
    tokenizer = BertTokenizer.from_pretrained(args.vocab_path,
                                              do_lower_case=args.do_lower_case)
    total_train_examples = samples_per_epoch * args.epochs

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = (num_train_optimization_steps //
                                        torch.distributed.get_world_size())
    args.warmup_steps = int(num_train_optimization_steps *
                            args.warmup_proportion)

    bert_config = ElectraConfig.from_pretrained(args.config_path,
                                                gen_weight=args.gen_weight,
                                                temperature=args.temperature,
                                                disc_weight=args.disc_weight)
    model = ElectraForPreTraining(config=bert_config)

    if args.continue_train:
        print(f"Continue train from {args.continue_train}")
        model = model.from_pretrained(args.continue_train)
    elif args.model_path:
        # The printed message below reads "loading the pre-trained model".
        print("载入预训练模型")
        model.generator = AutoModel.from_pretrained(args.model_path + "/G")
        model.electra = AutoModel.from_pretrained(args.model_path + "/D")

    model.to(device)
    # Prepare optimizer; biases and LayerNorm parameters get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(params=optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    global_step = 0
    g_metric = LMAccuracy()
    d_metric = AccuracyThresh()
    tr_g_acc = AverageMeter()
    tr_d_acc = AverageMeter()
    tr_loss = AverageMeter()
    tr_g_loss = AverageMeter()
    tr_d_loss = AverageMeter()

    train_logs = {}
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {total_train_examples}")
    logger.info(f"  Batch size = {args.train_batch_size}")
    logger.info(f"  Num steps = {num_train_optimization_steps}")
    logger.info(f"  warmup_steps = {args.warmup_steps}")
    logger.info(f"  Num workable gpus = {args.n_gpu}")

    start_time = time.time()
    seed_everything(args.seed)  # Added here for reproducibility
    for epoch in range(args.epochs):
        # Only visit the files counted above; samples_per_epoch and the
        # schedule length were derived from exactly these files.
        # NOTE(review): assumes PregeneratedDataset(file_id=i) reads the
        # "<data_name>_file_<i>" files checked above -- confirm.
        for idx in range(num_data_files):
            epoch_dataset = PregeneratedDataset(
                file_id=idx,
                training_path=pregenerated_data,
                tokenizer=tokenizer,
                reduce_memory=args.reduce_memory,
                data_name=args.data_name)
            if args.local_rank == -1:
                train_sampler = RandomSampler(epoch_dataset)
            else:
                train_sampler = DistributedSampler(epoch_dataset)
                # A fresh sampler starts at epoch 0 every time; vary it so
                # that each pass over a file is shuffled differently.
                train_sampler.set_epoch(epoch * num_data_files + idx)
            train_dataloader = DataLoader(epoch_dataset,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size)
            model.train()
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                outputs = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                masked_lm_labels=lm_label_ids)
                loss, g_loss, d_loss, d_logits, g_logits, is_replaced_label = outputs

                # Score the discriminator only on positions where the
                # attention mask is 1.
                active_indices = input_mask.view(-1) == 1
                active_logits = d_logits.view(-1)[active_indices]
                active_labels = is_replaced_label.view(-1)[active_indices]

                g_metric(logits=g_logits.view(-1, bert_config.vocab_size),
                         target=lm_label_ids.view(-1))
                d_metric(logits=active_logits.view(-1, 1),
                         target=active_labels)

                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                    g_loss = g_loss.mean()
                    d_loss = d_loss.mean()

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_g_acc.update(g_metric.value(), n=input_ids.size(0))
                tr_d_acc.update(d_metric.value(), n=input_ids.size(0))

                tr_loss.update(loss.item(), n=1)
                tr_g_loss.update(g_loss.item(), n=1)
                tr_d_loss.update(d_loss.item(), n=1)

                # Everything below happens once per optimizer update; while
                # gradients are still being accumulated global_step does not
                # advance, so logging/saving must not run.
                if (step + 1) % args.gradient_accumulation_steps != 0:
                    continue

                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                # The optimizer must step before the scheduler, otherwise the
                # first value of the learning-rate schedule is skipped.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                if (args.num_eval_steps > 0
                        and global_step % args.num_eval_steps == 0):
                    now = time.time()
                    # Despite the "ETA" label this is the wall time elapsed
                    # since the previous log line.
                    eta = now - start_time
                    if eta > 3600:
                        eta_format = ('%d:%02d:%02d' %
                                      (eta // 3600,
                                       (eta % 3600) // 60, eta % 60))
                    elif eta > 60:
                        eta_format = '%d:%02d' % (eta // 60, eta % 60)
                    else:
                        eta_format = '%ds' % eta
                    train_logs['loss'] = tr_loss.avg
                    train_logs['g_acc'] = tr_g_acc.avg
                    train_logs['d_acc'] = tr_d_acc.avg
                    train_logs['g_loss'] = tr_g_loss.avg
                    train_logs['d_loss'] = tr_d_loss.avg
                    show_info = f'[Training]:[{epoch}/{args.epochs}]{global_step}/{num_train_optimization_steps} ' \
                                f'- ETA: {eta_format}' + "-".join(
                        [f' {key}: {value:.4f} ' for key, value in train_logs.items()])
                    logger.info(show_info)
                    tr_g_acc.reset()
                    tr_d_acc.reset()
                    tr_loss.reset()
                    tr_g_loss.reset()
                    tr_d_loss.reset()
                    start_time = now

                # The "> 0" test comes first so that a value of 0 disables
                # checkpointing instead of raising ZeroDivisionError.
                if (args.local_rank in [-1, 0] and args.num_save_steps > 0
                        and global_step % args.num_save_steps == 0):
                    # Save model checkpoint
                    output_dir = args.output_dir / f'lm-checkpoint-{global_step}'
                    generator_dir = output_dir / "G"
                    discriminator_dir = output_dir / "D"
                    for directory in (output_dir, generator_dir,
                                      discriminator_dir):
                        directory.mkdir(parents=True, exist_ok=True)

                    # Unwrap DataParallel / DistributedDataParallel. On a
                    # single device there is no ``.module`` attribute, so
                    # every save below goes through model_to_save.
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(str(output_dir))
                    torch.save(args, str(output_dir / 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    model_to_save.generator.save_pretrained(
                        str(generator_dir))
                    logger.info("Saving generator model checkpoint to %s",
                                generator_dir)
                    model_to_save.electra.save_pretrained(
                        str(discriminator_dir))
                    logger.info("Saving electra model checkpoint to %s",
                                discriminator_dir)

                    torch.save(optimizer.state_dict(),
                               str(output_dir / "optimizer.bin"))

                    # save config
                    output_config_file = output_dir / CONFIG_NAME
                    output_config_file_D = discriminator_dir / CONFIG_NAME
                    output_config_file_G = generator_dir / CONFIG_NAME

                    with open(str(output_config_file), 'w') as f:
                        f.write(model_to_save.config.to_json_string())
                    with open(str(output_config_file_D), 'w') as f:
                        f.write(
                            model_to_save.electra.config.to_json_string())
                    with open(str(output_config_file_G), 'w') as f:
                        f.write(
                            model_to_save.generator.config.to_json_string())
                    # save vocab
                    tokenizer.save_vocabulary(str(output_dir))
Example #3
0
def take_train_steps(args, model, tokenizer, train_dataloader, prune):
    """Run a short burst of training steps around pruning callbacks.

    For every epoch at most 21 batches (``step`` 0 through 20) are processed
    before the inner loop stops. ``prune`` is notified at the beginning and
    end of every epoch and every processed batch.

    Args:
        args: Run configuration namespace. This function writes
            ``warmup_steps`` on it and, when ``max_steps > 0``, also
            ``num_train_epochs``.
        model: Model called as ``model(**inputs)``; its first output is used
            as the loss.
        tokenizer: Forwarded unchanged to ``evaluate``.
        train_dataloader: Iterable of batches; each batch is indexed as
            ``batch[0]`` input ids, ``batch[1]`` attention mask and
            ``batch[3]`` labels (``batch[2]`` is not passed to the model).
        prune: Object providing ``on_epoch_begin(epoch)``,
            ``on_batch_begin(step)``, ``on_batch_end()`` and
            ``on_epoch_end()``.
    """
    if args.max_steps > 0:
        num_training_steps = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        num_training_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    args.warmup_steps = int(num_training_steps * args.warmup_proportion)
    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = AdamW(params=optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
                                                num_training_steps=num_training_steps)
    logger.info("***** Running training *****")
    # The dataset is taken from the loader: this function receives no
    # ``train_dataset`` argument, so referring to that name would depend on
    # an unrelated global.
    logger.info("  Num examples = %d", len(train_dataloader.dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (
                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", num_training_steps)

    for epoch in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        prune.on_epoch_begin(epoch)
        for step, batch in enumerate(train_dataloader):
            prune.on_batch_begin(step)
            batch = tuple(t.to(args.device) for t in batch)
            # token_type_ids (batch[2]) are deliberately not fed to the model.
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'labels': batch[3]}
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once on the fully accumulated gradient; clipping after
                # every micro-batch would rescale earlier contributions
                # repeatedly.
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

            prune.on_batch_end()
            # Hard cap on the number of batches processed per epoch.
            if step >= 20:
                break
            # NOTE(review): this branch can never run. ``step % k == 20``
            # requires ``step >= 20`` and the loop has already stopped in
            # that case. Kept as-is because the intended evaluation point
            # is unclear -- confirm before enabling.
            if args.local_rank in [-1, 0] and args.logging_steps > 0 and step % args.logging_steps == 20:
                # Log metrics
                if args.local_rank == -1:  # Only evaluate when single GPU otherwise metrics may not average well
                    evaluate(args, model, tokenizer)

            # No checkpoint is written by this function.
            pbar(step, {'loss': loss.item()})
        prune.on_epoch_end()
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
Example #4
0
def train(args, train_dataset, model, tokenizer):
    """Train the model and save a single checkpoint at the end.

    Builds a randomly sampled data loader, an AdamW optimizer with a linear
    warmup/decay schedule, optionally wraps the model in ``DataParallel``,
    runs the training loop and finally writes the model and ``args`` to
    ``<output_dir>/checkpoint-<global_step>``.

    Args:
        args: Run configuration namespace. This function writes
            ``train_batch_size`` and ``warmup_steps`` on it and, when
            ``max_steps > 0``, also ``num_train_epochs``.
        train_dataset: Dataset handed to the ``DataLoader``; each collated
            batch is indexed as ``(input_ids, attention_mask,
            token_type_ids, labels)``.
        model: Model called as ``model(**inputs)``; its first output is used
            as the loss.
        tokenizer: Forwarded unchanged to ``evaluate``.

    Returns:
        ``(global_step, mean_loss)`` where ``global_step`` is the number of
        optimizer updates performed and ``mean_loss`` is the accumulated
        loss divided by ``global_step`` (``0.0`` if no update was performed).
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        num_training_steps = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
    args.warmup_steps = int(num_training_steps * args.warmup_proportion)
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(params=optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", num_training_steps)

    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    seed_everything(args.seed)  # Added here for reproducibility
    reached_max_steps = False
    for _ in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            # Re-enter train mode every step; the periodic evaluate() call
            # below may have switched the model to eval mode (not visible
            # from this function).
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3],
                'token_type_ids': batch[2]
            }
            outputs = model(**inputs)
            loss = outputs[0]

            if args.n_gpu > 1:
                # Average the per-device losses returned by DataParallel.
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once on the fully accumulated gradient; clipping after
                # every micro-batch would rescale earlier contributions
                # repeatedly.
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                # Evaluate only right after a real optimizer update, so it
                # neither fires at global_step == 0 nor repeats for every
                # micro-batch of one accumulation window.
                if (args.local_rank == -1 and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    evaluate(args, model, tokenizer)

            # No intermediate checkpoints are written; see the single save
            # after the loop.
            pbar(step, {'loss': loss.item()})

            # Honour --max_steps: the schedule above is sized for exactly
            # this many updates.
            if args.max_steps > 0 and global_step >= args.max_steps:
                reached_max_steps = True
                break

        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
        if reached_max_steps:
            break

    # Final checkpoint, written unconditionally once training has finished.
    output_dir = os.path.join(args.output_dir,
                              'checkpoint-{}'.format(global_step))
    os.makedirs(output_dir, exist_ok=True)
    # Unwrap DataParallel before saving.
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
    logger.info("Saving model checkpoint to %s", output_dir)

    # Guard against a run in which no optimizer update happened.
    mean_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, mean_loss
 for name, param in param_optimizer:
     print(f" param size {name}-->{param.size()} ")
 no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
 optimizer_grouped_parameters = [{
     'params':
     [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay':
     args.weight_decay
 }, {
     'params':
     [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay':
     0.0
 }]
 optimizer = AdamW(params=optimizer_grouped_parameters,
                   lr=args.learning_rate,
                   eps=args.adam_epsilon)
 scheduler = get_linear_schedule_with_warmup(
     optimizer,
     num_warmup_steps=args.warmup_steps,
     num_training_steps=num_train_optimization_steps)
 # optimizer = Lamb(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
 if args.model_path:
     optimizer.load_state_dict(
         torch.load(args.model_path + "/optimizer.bin"))
 if args.fp16:
     try:
         from apex import amp
     except ImportError:
         raise ImportError(
             "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."