Example #1
0
def main(args):

    logging = config.get_logging(args.log_name)
    logging.info(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    tokenizer = BertTokenizer.build_tokenizer(args)
    # train_data_iter = MSmarco_iterator(args, tokenizer, batch_size=args.train_batch_size, world_size=n_gpu, accumulation_steps=args.gradient_accumulation_steps, name="msmarco_train.pk")
    dev_data_iter = MSmarco_iterator(args, tokenizer, batch_size=args.valid_batch_size, world_size=n_gpu, name="msmarco_dev.pk")

    logging.info("| dev batch data size {}".format(len(dev_data_iter)))


    # num_train_steps = (96032//2//2)+(data_size-96032)//2
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    pre_dir = args.pre_dir
    config_file = os.path.join(pre_dir, CONFIG_NAME)
    bert_config = BertConfig.from_json_file(config_file)
    model = MSmarco(bert_config)
    logging.info("| load model from {}".format(args.path))
    state_dict = torch.load(args.path, map_location=torch.device('cpu'))
    metadata = getattr(state_dict, '_metadata', None)
    # state_dict = state_dict.copy()
    # if metadata is not None:
    #     state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix='module.')
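    # Note: the checkpoint was saved from a torch.nn.DataParallel-wrapped model, so every
    # key carries a "module." prefix, which is why load() is called with prefix='module.'.
    # A hedged alternative sketch (not what this code does) is to strip the prefix and use
    # load_state_dict directly:
    #     stripped = {k[len('module.'):]: v for k, v in state_dict.items()
    #                 if k.startswith('module.')}
    #     model.load_state_dict(stripped, strict=False)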

    if len(missing_keys) > 0:
        # logger.info("Weights of {} not initialized from pretrained model: {}".format(
        #     model.__class__.__name__, missing_keys))
        print("| Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        # logger.info("Weights from pretrained model not used in {}: {}".format(
            # model.__class__.__name__, unexpected_keys))
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))

    # model._load_from_state_dict(state_dict, prefix="module.")
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

        # save_checkpoint(args, model, epochs)
    validation(args, model, dev_data_iter, n_gpu, 0, 0, logging)
Example #2
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
                             "of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', 
                        type=int, 
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--optimize_on_cpu',
                        default=False,
                        action='store_true',
                        help="Whether to perform optimization and keep the optimizer averages on CPU")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=128,
                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
    parser.add_argument('--do_lower_case',
                        default=False, action='store_true',
                        help='Whether to lowercase the input text (set this for uncased models).')
    parser.add_argument('--do-test',
                        default=False, action='store_true',
                        help='If set, truncate the train and dev data to a small subset for quick testing.')
    parser.add_argument("--pre-dir", type=str,
                        help="where the pretrained checkpoint")


    args = parser.parse_args()
    print(args)
    # local_rank: set for multi-node / distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
        if args.fp16:
            logger.info("16-bits training currently not supported in distributed training")
            args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496)
    # logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits trainiing: {}".format(
    #     device, n_gpu, bool(args.local_rank != -1), args.fp16))
    # gradient_accumulation_steps == freq_update
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))
    # Shrink the per-step batch size by the number of accumulation steps
    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
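    # Worked example (illustrative numbers): with --train_batch_size 32 and
    # --gradient_accumulation_steps 4, each forward/backward pass uses a batch of
    # 32 / 4 = 8, and gradients from 4 such passes are accumulated before every
    # optimizer.step(), so the effective batch size is still 32.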
    
    print("| gpu count : {}".format(n_gpu))
    print("| train batch size in each gpu : {}".format(args.train_batch_size))
    print("| biuid tokenizer and model in : {}".format(args.pre_dir))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory () already exists and is not empty.")
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = BertTokenizer.build_tokenizer(args)

    train_examples = None
    # Total number of optimizer updates over training
    num_train_steps = None
    if args.do_train:
        # Load the training data
        # When --do-test is set, it is truncated to a small subset below
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True)
        if args.do_test:
            train_examples = train_examples[:1000]
        num_train_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
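        # Worked example (illustrative numbers): 88,000 training examples, per-step
        # batch size 8 (already divided by the accumulation steps above), accumulation 4
        # and 3 epochs give int(88000 / 8 / 4 * 3) = 8250 optimizer updates; this is the
        # t_total handed to BertAdam further down.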

    # Prepare model
    # model = BertForQuestionAnswering.from_pretrained(args.bert_model,
                # cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
    model = BertForQuestionAnswering.build_model(args)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.fp16:
        param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \
                            for n, param in model.named_parameters()]
    elif args.optimize_on_cpu:
        param_optimizer = [(n, param.clone().detach().to('cpu').requires_grad_()) \
                            for n, param in model.named_parameters()]
    else:
        # Collect all parameters together with their names
        param_optimizer = list(model.named_parameters())
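    # In the fp16 / optimize_on_cpu branches above, the optimizer operates on fp32 CPU
    # copies of the parameters ("master weights"): gradients are copied from the model
    # into these copies (set_optimizer_params_grad), the step is taken in fp32, and the
    # updated values are copied back into the model (copy_optimizer_params_to_model)
    # in the training loop below.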
    no_decay = ['bias', 'gamma', 'beta']
    # for n,v in param_optimizer:
    #     print("| name is {}\n".format(n))
    # # print(oo)

    # Split the model parameters into two groups:
    # parameters whose names contain one of no_decay = ['bias', 'gamma', 'beta'] -> 'weight_decay_rate': 0.0
    # all remaining parameters -> 'weight_decay_rate': 0.01
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},

        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
        ]
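    # The names 'gamma' and 'beta' come from the TensorFlow BERT release, where they are
    # the LayerNorm scale and shift; in PyTorch ports these parameters may instead be named
    # LayerNorm.weight / LayerNorm.bias, in which case only the 'bias' filter matches.
    # The intent is: no weight decay on biases and LayerNorm parameters, 0.01 on the rest.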
    # Optimizer
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,  # max total length after tokenization, default 384
            doc_stride=args.doc_stride,  # stride between document chunks, default 128
            max_query_length=args.max_query_length,  # max query length after tokenization, default 64
            is_training=True)
        # logger.info("|  orig train data = %d", len(train_examples))
        # logger.info("|  features train data = %d", len(train_features))
        # logger.info("|  Batch size = %d", args.train_batch_size)
        # logger.info("|  Num steps = %d", num_train_steps)
        print("| train data count {}, batch size {}, num steps {}".format(len(train_features), args.train_batch_size, num_train_steps))
        # Every feature is padded to the same length
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        # Padding mask: 1 for real input tokens, 0 for padding
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        # Segment ids mark sentence A/B as 0/1; padding positions are 0
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        # target
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)

        # Dimension 0 indexes the examples; every tensor passed in must have the same first-dimension size
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions)
        if args.local_rank == -1:
            # Default (non-distributed) branch
            # Returns a random permutation of indices covering all of the data
            train_sampler = RandomSampler(train_data)
        else:
            ## edit 2: this had previously been applied to the wrong branch

            # If the data were sorted by length, some GPUs would receive mostly short
            # sequences, which could hurt efficiency or cause other problems.
            # Because of how this dataset is built, every example has the same (padded)
            # length, so that is not an issue here.
            # Returns an iterator that gives each rank its own slice of the indices,
            # roughly [len*rank, len + len*rank].

            # What comes back is a set of indices
            print("| in distributedSample")
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
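        # Note: with DistributedSampler, shuffling only changes between epochs if
        # train_sampler.set_epoch(epoch) is called at the start of each epoch.
        # A minimal sketch of where that call would go (not present in this code):
        #     for epoch in range(int(args.num_train_epochs)):
        #         if args.local_rank != -1:
        #             train_sampler.set_epoch(epoch)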

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                if n_gpu == 1:
                    batch = tuple(t.to(device) for t in batch) # multi-gpu does the scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions = batch
                loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                # print("| loss  is {}".format(loss))
                # print("| loss size is {}".format(loss.size()))
                # print(oo)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.fp16 and args.loss_scale != 1.0:
                    # rescale loss for fp16 training
                    # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
                    loss = loss * args.loss_scale
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
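                # Net effect: the backward pass runs on loss * loss_scale / gradient_accumulation_steps;
                # the loss_scale factor keeps small fp16 gradients representable and is divided
                # back out of the gradients below before optimizer.step().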
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16 or args.optimize_on_cpu:
                        if args.fp16 and args.loss_scale != 1.0:
                            # scale down gradients for fp16 training
                            for param in model.parameters():
                                if param.grad is not None:
                                    param.grad.data = param.grad.data / args.loss_scale
                        is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
                        if is_nan:
                            logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
                            args.loss_scale = args.loss_scale / 2
                            model.zero_grad()
                            continue
                        optimizer.step()
                        copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
                    else:
                        optimizer.step()
                    # optimizer.step()
                    model.zero_grad()
                    global_step += 1

    if args.do_predict:
        eval_examples = read_squad_examples(
            input_file=args.predict_file, is_training=False)
        if args.do_test:
            eval_examples = eval_examples[:1000]
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        # logger.info("| Running predictions *****")
        # logger.info("| orig dev data = %d", len(eval_examples))
        # logger.info("| split dev data = %d", len(eval_features))
        # logger.info("| dev batch = %d", args.predict_batch_size)
        print("\n| dev data count {}, batch size {}".format(len(eval_features), args.predict_batch_size))

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        if args.local_rank == -1:

            eval_sampler = SequentialSampler(eval_data)
        else:
            eval_sampler = DistributedSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        # logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, args.verbose_logging)
Example #3
0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--save-dir",
                        default="checkpoints",
                        type=str,
                        help="path to save checkpoints")

    ## Other parameters
    parser.add_argument("--data",
                        default="data",
                        type=str,
                        help="MSmarco train and dev data")
    parser.add_argument("--origin-data",
                        default="data",
                        type=str,
                        help="MSmarco train and dev data, will be tokenizer")
    parser.add_argument("--path",
                        default="data",
                        type=str,
                        help="path(s) to model file(s), colon separated")
    parser.add_argument("--save",
                        default="checkpoints/MSmarco",
                        type=str,
                        help="path(s) to model file(s), colon separated")
    parser.add_argument("--pre-dir",
                        type=str,
                        help="where the pretrained checkpoint")
    parser.add_argument("--log-name", type=str, help="where logfile")
    parser.add_argument(
        "--max-passage-tokens",
        default=200,
        type=int,
        help=
        "The maximum total input passage length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--max-query-tokens",
        default=50,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument(
        '--gradient-accumulation-steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train-batch-size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict-batch-size",
                        default=1,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--lr",
                        default=6.25e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num-train-epochs",
                        default=3,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup-proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--do-lower-case',
                        default=False,
                        action='store_true',
                        help='Whether to lowercase the input text (set this for uncased models).')
    parser.add_argument('--threshold', type=float, default=0.36)
    parser.add_argument('--logfile', type=str, default="logfile.log")
    parser.add_argument('--validate-updates',
                        type=int,
                        default=30000,
                        metavar='N',
                        help='validate every N updates')
    parser.add_argument('--loss-interval',
                        type=int,
                        default=5000,
                        metavar='N',
                        help='log the training loss every N updates')
    args = parser.parse_args()
    # global logger
    # logger = logging.getLogger(args.log_name)
    # logger.error("| f**k logger")

    # first make corpus
    # tokenizer = BertTokenizer.build_tokenizer(args)
    # make_msmarco(args, tokenizer)

    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    print("| gpu count : {}".format(n_gpu))
    print("| train batch size in each gpu : {}".format(args.train_batch_size))
    print("| biuid tokenizer and model in : {}".format(args.pre_dir))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    tokenizer = BertTokenizer.build_tokenizer(args)
    train_data_iter = MSmarco_iterator(
        args,
        tokenizer,
        batch_size=args.train_batch_size,
        world_size=n_gpu,
        accumulation_steps=args.gradient_accumulation_steps,
        name="msmarco_train.pk")
    dev_data_iter = MSmarco_iterator(args,
                                     tokenizer,
                                     batch_size=args.train_batch_size,
                                     world_size=n_gpu,
                                     name="msmarco_dev.pk")
    gradient_accumulation_steps = args.gradient_accumulation_steps
    data_size = len(train_data_iter)
    num_train_steps = args.num_train_epochs * data_size
    print("| load dataset {}".format(data_size))

    model = ParallelMSmarco.build_model(args)
    cls_criterion = nn.KLDivLoss()
    model.to(device)
    if n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)
        cls_criterion = DataParallelCriterion(cls_criterion)
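    # DataParallelModel / DataParallelCriterion (from the parallel helper module this repo
    # appears to use) keep the model outputs scattered across GPUs and evaluate the criterion
    # on each device, returning one loss value per GPU; that is why the loss is reduced with
    # .sum() in the training loop below instead of gathering the logits onto a single device.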

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.lr,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

    global_update = 0
    for epochs in range(args.num_train_epochs):
        total_loss = 0
        for step, batch in enumerate(
                tqdm(train_data_iter, desc="Train Iteration")):
            for key in batch.keys():
                batch[key] = batch[key].to(device)  # .to() is not in-place, so reassign
            targets = batch["targets"]
            batch.pop("targets")
            model.train()
            loss_logits = model(**batch)
            # pdb.set_trace()
            loss = cls_criterion(loss_logits, targets)
            if n_gpu > 1:
                loss = loss.sum()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                model.zero_grad()
                global_update += 1
            # print("| loss {}".format(loss.size()))

            # optimizer.step()
            # model.zero_grad()
            # global_update += 1
            if global_update > 0 and global_update % args.validate_updates == 0:
                validation(args, model, cls_criterion, dev_data_iter, n_gpu,
                           epochs, global_update)
            if global_update > 0 and global_update % args.loss_interval == 0:
                logging.info(
                    "TRAIN ::Epoch {} updates {}, train loss {}".format(
                        epochs, global_update, loss.item()))
        save_checkpoint(args, model, epochs)
        validation(args, model, cls_criterion, dev_data_iter, n_gpu, epochs,
                   global_update)
Example #4
0
def main(args):


    logging = config.get_logging(args.log_name)
    logging.info("##"*20)
    logging.info("##"*20)
    logging.info("##"*20)
    logging.info(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    logging.info("| question first :: {}".format(args.question_first))
    logging.info("| gpu count : {}".format(n_gpu))
    logging.info("| train batch size in each gpu : {}".format(args.train_batch_size))
    logging.info("| biuid tokenizer and model in : {}".format(args.pre_dir))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


    tokenizer = BertTokenizer.build_tokenizer(args)
    train_data_iter = MSmarco_iterator(args, tokenizer, batch_size=args.train_batch_size, world_size=n_gpu, accumulation_steps=args.gradient_accumulation_steps, name="msmarco_train.pk")
    dev_data_iter = MSmarco_iterator(args, tokenizer, batch_size=args.valid_batch_size, world_size=n_gpu, name="msmarco_dev.pk")
    data_size = len(train_data_iter)
    gradient_accumulation_steps = args.gradient_accumulation_steps
    num_train_steps = args.num_train_epochs*data_size//gradient_accumulation_steps
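    # Worked example (illustrative numbers): with 3 epochs, 120,000 batches per epoch and
    # gradient_accumulation_steps=2, num_train_steps = 3 * 120000 // 2 = 180000 optimizer
    # updates, which is the t_total passed to BertAdam below.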
    # logging.info("| load dataset {}".format(data_size))
    logging.info("| train data size {}".format(len(train_data_iter)*n_gpu*args.train_batch_size))
    logging.info("| dev data size {}".format(len(dev_data_iter)*n_gpu*args.valid_batch_size))
    logging.info("| train batch data size {}".format(len(train_data_iter)))
    logging.info("| dev batch data size {}".format(len(dev_data_iter)))
    logging.info("| update in each train data {}".format(data_size//gradient_accumulation_steps))
    logging.info("| total update {}".format(num_train_steps))

    # num_train_steps = (96032//2//2)+(data_size-96032)//2

    model = MSmarco.build_model(args)
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'layer_norm']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},

        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
        ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.lr,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)
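    # If this BertAdam follows the pytorch-pretrained-bert implementation, its default
    # schedule is "warmup_linear": the learning rate ramps linearly from 0 to args.lr over
    # the first warmup_proportion * t_total updates, then decays linearly back towards 0 at
    # t_total, so the value logged on the next line reflects the schedule, not args.lr itself.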
    logging.info("| init lr is {}".format(optimizer.get_lr()))

    global_update = 0
    for epochs in range(args.num_train_epochs):
        total_loss = 0
        merge_batch = []
        # count = 0
        for step, batch in enumerate(tqdm(train_data_iter, desc="Train Iteration")):
            model.train()
            # if step < 96032:
            #     merge_batch.append(batch)
            #     if len(merge_batch) == 2:
            #         batch = merger_tensor(merge_batch)
            #         merge_batch = []
            #     else:
            #         continue
            if n_gpu==1:
                for key in batch.keys():
                    batch[key]=batch[key].to(device)
            loss = model(**batch)
            # count += 1
            # pdb.set_trace()
            if n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss/args.gradient_accumulation_steps
            loss.backward()
            if (step+1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                model.zero_grad()
                global_update += 1
                if global_update % args.validate_updates==0:
                    validation(args, model, dev_data_iter, n_gpu, epochs, global_update, logging)
            if (step+1) % args.loss_interval==0:
                logging.info("TRAIN ::Epoch {} updates {}, train loss {}".format(epochs, global_update, loss.item()))
        # save_checkpoint(args, model, epochs)
        validation(args, model, dev_data_iter, n_gpu, epochs, global_update, logging)