示例#1
0
def main():
    """Evaluate or run prediction with a BERT pre-training checkpoint.

    Steps performed:

    1. Parse the command line.
    2. Select the device; when a local rank is given (via ``--local_rank`` or
       the ``LOCAL_RANK`` environment variable) and CUDA is not disabled, the
       NCCL process group is initialised.
    3. Configure dllogger (real backends only on the main process).
    4. Build the model from ``--config_file`` and restore weights from either
       ``--ckpt_path`` or ``<ckpt_dir>/ckpt_<step>.pt`` (latest step when
       ``--ckpt_step`` is -1).
    5. Iterate over every regular file in ``--input_dir`` whose name contains
       ``test``.  With ``--eval`` the mean loss over all processed batches is
       logged (averaged over ranks when distributed); with ``--prediction``
       only forward passes are run.
    6. Log the throughput in sequences per second for this process.

    Raises:
        RuntimeError: in ``--eval`` mode when no batch was processed, because
            a mean loss cannot be computed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain .hdf5 files  for the task.")
    parser.add_argument("--config_file",
                        default="bert_config.json",
                        type=str,
                        required=False,
                        help="The BERT model config")
    # Exactly one of --ckpt_dir / --ckpt_path must be supplied.
    ckpt_group = parser.add_mutually_exclusive_group(required=True)
    ckpt_group.add_argument("--ckpt_dir",
                            default=None,
                            type=str,
                            help="The ckpt directory, e.g. /results")
    ckpt_group.add_argument("--ckpt_path",
                            default=None,
                            type=str,
                            help="Path to the specific checkpoint")

    # --eval and --prediction share one destination: do_eval is True/False.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--eval', dest='do_eval', action='store_true')
    group.add_argument('--prediction', dest='do_eval', action='store_false')
    ## Other parameters
    parser.add_argument(
        "--bert_model",
        default="bert-large-uncased",
        type=str,
        required=False,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--max_predictions_per_seq",
        default=80,
        type=int,
        help="The maximum total of masked tokens in input sequence")
    parser.add_argument("--ckpt_step",
                        default=-1,
                        type=int,
                        required=False,
                        help="The model checkpoint iteration, e.g. 1000")

    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "Total number of eval  steps to perform, otherwise use full dataset")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument("--log_path",
                        help="Out file for DLLogger",
                        default="/workspace/dllogger_inference.out",
                        type=str)

    args = parser.parse_args()

    # The environment variable takes precedence over the command-line value.
    if 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")

    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    if is_main_process():
        dllogger.init(backends=[
            dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                       filename=args.log_path),
            dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                   step_format=format_step)
        ])
    else:
        # Other processes get a logger without backends.
        dllogger.init(backends=[])

    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        assert (args.local_rank != -1
                )  # only use torch.distributed for multi-gpu

    dllogger.log(
        step=
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16),
        data={})

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare model
    config = BertConfig.from_json_file(args.config_file)
    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
    model = BertForPreTraining(config)

    if args.ckpt_dir:
        if args.ckpt_step == -1:
            # Retrieve the latest model: the step is parsed from file names
            # of the form "<prefix>_<step>.pt".
            args.ckpt_step = max(
                int(name.split('.pt')[0].split('_')[1].strip())
                for name in os.listdir(args.ckpt_dir)
                if name.endswith(".pt"))
            dllogger.log(step="load model saved at iteration",
                         data={"number": args.ckpt_step})
        model_file = os.path.join(args.ckpt_dir,
                                  "ckpt_" + str(args.ckpt_step) + ".pt")
    else:
        model_file = args.ckpt_path
    state_dict = torch.load(model_file, map_location="cpu")["model"]
    model.load_state_dict(state_dict, strict=False)

    if args.fp16:
        model.half(
        )  # all parameters and buffers are converted to half precision
    model.to(device)

    multi_gpu_training = (args.local_rank != -1
                          and torch.distributed.is_initialized())
    if multi_gpu_training:
        model = DDP(model)

    files = sorted(
        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, f)) and 'test' in f)

    dllogger.log(step="***** Running Inference *****", data={})
    dllogger.log(step="  Inference batch", data={"size": args.eval_batch_size})

    model.eval()

    def open_loader(data_file):
        """Log the shard being opened and return a DataLoader over it."""
        dllogger.log(step="Opening ", data={"file": data_file})
        dataset = pretraining_dataset(
            input_file=data_file,
            max_pred_length=args.max_predictions_per_seq)
        if multi_gpu_training:
            sampler = DistributedSampler(dataset)
        else:
            sampler = RandomSampler(dataset)
        return DataLoader(dataset,
                          sampler=sampler,
                          batch_size=args.eval_batch_size,
                          num_workers=4,
                          pin_memory=True)

    max_steps = args.max_steps if args.max_steps > 0 else np.inf
    global_step = 0
    # Number of sequences actually forwarded by this process; counted per
    # batch so early stops and a short final batch are reflected correctly.
    total_samples = 0

    begin_infer = time.time()
    with torch.no_grad():
        if args.do_eval:
            loss_sum = 0.0
            for data_file in files:
                for batch in tqdm(open_loader(data_file), desc="Iteration"):
                    # ">=" so that exactly max_steps batches are processed.
                    if global_step >= max_steps:
                        break
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    loss = model(input_ids=input_ids,
                                 token_type_ids=segment_ids,
                                 attention_mask=input_mask,
                                 masked_lm_labels=masked_lm_labels,
                                 next_sentence_label=next_sentence_labels)
                    loss_sum += loss.item()

                    total_samples += input_ids.size(0)
                    global_step += 1

                torch.cuda.empty_cache()
                if global_step >= max_steps:
                    break

            if global_step == 0:
                raise RuntimeError(
                    "No evaluation batch was processed; cannot compute a "
                    "mean loss (input_dir: {})".format(args.input_dir))

            # Kept as a plain float so logging works with and without
            # torch.distributed.
            final_loss = loss_sum / global_step
            if multi_gpu_training:
                # Sum the per-rank means, then divide by the world size.
                loss_tensor = torch.tensor(final_loss, device=device)
                torch.distributed.all_reduce(loss_tensor)
                loss_tensor /= torch.distributed.get_world_size()
                final_loss = loss_tensor.item()
            if not multi_gpu_training or torch.distributed.get_rank() == 0:
                dllogger.log(step="Inference Loss",
                             data={"final_loss": final_loss})

        else:  # inference
            for data_file in files:
                for batch in tqdm(open_loader(data_file), desc="Iteration"):
                    if global_step >= max_steps:
                        break

                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                    # Labels are withheld; the model outputs are discarded,
                    # only the forward pass itself is of interest here.
                    model(input_ids=input_ids,
                          token_type_ids=segment_ids,
                          attention_mask=input_mask,
                          masked_lm_labels=None,
                          next_sentence_label=None)

                    total_samples += input_ids.size(0)
                    global_step += 1

                torch.cuda.empty_cache()
                if global_step >= max_steps:
                    break

            if not multi_gpu_training or torch.distributed.get_rank() == 0:
                dllogger.log(step="Done Inferring on samples", data={})

    end_infer = time.time()
    dllogger.log(step="Inference perf",
                 data={
                     "inference_sequences_per_second":
                     total_samples / (end_infer - begin_infer)
                 })
示例#2
0
def main():
    """Evaluate or run prediction with a BERT pre-training checkpoint.

    Parses the command line, selects the device (initialising the NCCL
    process group when ``--local_rank`` is given and CUDA is not disabled),
    builds the model from ``--config_file`` and restores
    ``<ckpt_dir>/ckpt_<step>.model`` (latest step when ``--ckpt_step`` is
    -1).  Every regular file in ``--input_dir`` is then processed: with
    ``--eval`` the mean loss over all processed batches is logged (averaged
    over ranks when distributed); with ``--prediction`` only forward passes
    are run.

    Raises:
        RuntimeError: in ``--eval`` mode when no batch was processed, because
            a mean loss cannot be computed.
    """
    print("IN NEW MAIN XD\n")
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain .hdf5 files  for the task.")
    parser.add_argument("--config_file",
                        default="bert_config.json",
                        type=str,
                        required=False,
                        help="The BERT model config")
    parser.add_argument("--ckpt_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The ckpt directory, e.g. /results")

    # --eval and --prediction share one destination: do_eval is True/False.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--eval', dest='do_eval', action='store_true')
    group.add_argument('--prediction', dest='do_eval', action='store_false')
    ## Other parameters
    parser.add_argument(
        "--bert_model",
        default="bert-large-uncased",
        type=str,
        required=False,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--max_predictions_per_seq",
        default=80,
        type=int,
        help="The maximum total of masked tokens in input sequence")
    parser.add_argument("--ckpt_step",
                        default=-1,
                        type=int,
                        required=False,
                        help="The model checkpoint iteration, e.g. 1000")

    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "Total number of eval  steps to perform, otherwise use full dataset")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")

    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        assert (args.local_rank != -1
                )  # only use torch.distributed for multi-gpu
    logger.info("device %s n_gpu %d distributed inference %r", device, n_gpu,
                bool(args.local_rank != -1))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare model
    config = BertConfig.from_json_file(args.config_file)
    model = BertForPreTraining(config)

    if args.ckpt_step == -1:
        # Retrieve the latest model: the step is parsed from file names of
        # the form "<prefix>_<step>.model".
        args.ckpt_step = max(
            int(name.split('.model')[0].split('_')[1].strip())
            for name in os.listdir(args.ckpt_dir)
            if name.endswith(".model"))
        print("load model saved at iteraton", args.ckpt_step)
    model_file = os.path.join(args.ckpt_dir,
                              "ckpt_" + str(args.ckpt_step) + ".model")
    state_dict = torch.load(model_file, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)

    if args.fp16:
        model.half(
        )  # all parameters and buffers are converted to half precision
    model.to(device)

    multi_gpu_training = (args.local_rank != -1
                          and torch.distributed.is_initialized())
    if multi_gpu_training:
        model = DDP(model)

    files = sorted(
        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, f)))

    logger.info("***** Running evaluation *****")
    logger.info("  Batch size = %d", args.eval_batch_size)

    model.eval()
    print("Evaluation. . .")

    def open_loader(data_file):
        """Log the file being opened and return a DataLoader over it."""
        logger.info("file %s" % (data_file))
        dataset = pretraining_dataset(
            input_file=data_file,
            max_pred_length=args.max_predictions_per_seq)
        if multi_gpu_training:
            sampler = DistributedSampler(dataset)
        else:
            sampler = RandomSampler(dataset)
        return DataLoader(dataset,
                          sampler=sampler,
                          batch_size=args.eval_batch_size,
                          num_workers=4,
                          pin_memory=True)

    max_steps = args.max_steps if args.max_steps > 0 else np.inf
    global_step = 0

    with torch.no_grad():
        if args.do_eval:
            # Becomes a tensor after the first addition of a loss.
            final_loss = 0.0
            for data_file in files:
                for batch in tqdm(open_loader(data_file), desc="Iteration"):
                    # ">=" so that exactly max_steps batches are processed.
                    if global_step >= max_steps:
                        break

                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    loss = model(input_ids=input_ids,
                                 token_type_ids=segment_ids,
                                 attention_mask=input_mask,
                                 masked_lm_labels=masked_lm_labels,
                                 next_sentence_label=next_sentence_labels)
                    final_loss += loss

                    global_step += 1

                torch.cuda.empty_cache()
                if global_step >= max_steps:
                    break

            if global_step == 0:
                raise RuntimeError(
                    "No evaluation batch was processed; cannot compute a "
                    "mean loss (input_dir: {})".format(args.input_dir))

            final_loss /= global_step
            if multi_gpu_training:
                # Pre-divide by the world size so the summing all-reduce
                # yields the mean over ranks.
                final_loss /= torch.distributed.get_world_size()
                torch.distributed.all_reduce(final_loss)
            if not multi_gpu_training or torch.distributed.get_rank() == 0:
                logger.info("Finished: Final Loss = {}".format(final_loss))

        else:  # inference
            for data_file in files:
                for batch in tqdm(open_loader(data_file), desc="Iteration"):
                    if global_step >= max_steps:
                        break

                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch

                    # Labels are withheld; the model outputs are discarded,
                    # only the forward pass itself is of interest here.
                    model(input_ids=input_ids,
                          token_type_ids=segment_ids,
                          attention_mask=input_mask,
                          masked_lm_labels=None,
                          next_sentence_label=None)

                    global_step += 1

                torch.cuda.empty_cache()
                if global_step >= max_steps:
                    break

            if not multi_gpu_training or torch.distributed.get_rank() == 0:
                logger.info("Finished")
示例#3
0
def inference(args):
    """Run a single timed forward pass of BERT through a static graph.

    The model is constructed from the hyper-parameters on ``args``, its
    weights are loaded from ``args.model_path`` (through the lazy-model
    converter when ``args.use_lazy_model`` is set), and it is wrapped in an
    ``nn.Graph`` that returns the second output of the model.  One random
    sequence of length ``args.seq_length`` is pushed through the graph and
    the result is printed.  The elapsed time of each phase (model
    construction, parameter loading, inference) is printed as well.

    Args:
        args: configuration object providing the model hyper-parameters,
            ``model_path``, ``use_lazy_model``, ``device`` and ``seq_length``.
    """
    # --- Phase 1: model construction ---------------------------------------
    tic = time.time()
    bert_module = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        args.hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        args.intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
        args.vocab_size,
    )
    print("Initialize model using time: {:.3f}s".format(time.time() - tic))

    # --- Phase 2: parameter loading ----------------------------------------
    tic = time.time()
    if not args.use_lazy_model:
        bert_module.load_state_dict(flow.load(args.model_path))
    else:
        # Imported lazily: only needed for lazy-model checkpoints.
        from utils.compare_lazy_outputs import load_params_from_lazy

        load_params_from_lazy(bert_module.state_dict(), args.model_path)
    print("Loading parameters using time: {:.3f}s".format(time.time() - tic))

    bert_module.eval()
    bert_module.to(args.device)

    class BertEvalGraph(nn.Graph):
        """Graph wrapper around ``bert_module`` for evaluation."""

        def __init__(self):
            super().__init__()
            self.bert = bert_module

        def build(self, input_ids, input_masks, segment_ids):
            # Move all three inputs to the configured device first.
            target = args.device
            input_ids = input_ids.to(device=target)
            input_masks = input_masks.to(device=target)
            segment_ids = segment_ids.to(device=target)

            with flow.no_grad():
                # Forward the next_sentence_prediction and masked_lm model;
                # only the second output is returned.
                _, seq_relationship_scores = self.bert(
                    input_ids, input_masks, segment_ids)

            return seq_relationship_scores

    bert_eval_graph = BertEvalGraph()

    # --- Phase 3: inference on one random sequence -------------------------
    tic = time.time()
    token_ids = flow.Tensor(
        [np.random.randint(0, 20, size=args.seq_length)],
        dtype=flow.int64,
        device=flow.device(args.device),
    )
    # Positions holding token id 0 are masked out.
    attention_mask = flow.cast(token_ids > 0, dtype=flow.int64)
    segment_info = flow.zeros_like(token_ids)

    prediction = bert_eval_graph(token_ids, attention_mask, segment_info)
    print(prediction.numpy())
    print("Inference using time: {:.3f}".format(time.time() - tic))
def main():
    """Pre-train BERT, validating and saving a checkpoint after every epoch.

    Reads the configuration via ``get_config()``, creates the train and test
    OFRecord data loaders, builds the model (hidden size is
    ``64 * num_attention_heads``, intermediate size four times that), the
    optimizer and a polynomial-decay learning-rate schedule with linear
    warm-up.  Each epoch runs ``len(train_data_loader)`` calls to
    ``pretrain`` followed by ``validation`` and ``save_model``.
    """
    args = get_config()

    device = flow.device("cuda") if args.with_cuda else flow.device("cpu")

    print("Creating Dataloader")
    # Settings shared by the train and test loaders.
    common_loader_kwargs = dict(
        ofrecord_dir=args.ofrecord_path,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=False,
    )
    train_data_loader = OfRecordDataLoader(
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_batch_size,
        data_part_num=args.train_data_part,
        **common_loader_kwargs,
    )
    test_data_loader = OfRecordDataLoader(
        mode="test",
        dataset_size=1024,
        batch_size=args.val_batch_size,
        data_part_num=4,
        **common_loader_kwargs,
    )

    print("Building BERT Model")
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    bert_model = bert_model.to(device)
    if args.use_ddp:
        bert_model = ddp(bert_model)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    # Polynomial decay to zero over the whole run, wrapped in linear warm-up.
    total_steps = args.epochs * len(train_data_loader)
    warmup_steps = int(total_steps * args.warmup_proportion)
    decay_schedule = PolynomialLR(optimizer,
                                  steps=total_steps,
                                  end_learning_rate=0.0)
    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        decay_schedule,
        warmup_factor=0,
        warmup_iters=warmup_steps,
        warmup_method="linear",
    )

    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    def get_masked_lm_loss(
        logit_blob,
        masked_lm_positions,
        masked_lm_labels,
        label_weights,
        max_prediction_per_seq,
    ):
        """Return the weighted mean cross-entropy over the masked positions."""
        # Select the logits at the masked positions along dim 1.
        gather_index = masked_lm_positions.unsqueeze(2).repeat(
            1, 1, args.vocab_size)
        picked_logits = flow.gather(logit_blob, index=gather_index, dim=1)

        flat_logits = flow.reshape(picked_logits, [-1, args.vocab_size])
        flat_labels = flow.reshape(masked_lm_labels, [-1])

        # Per the original note: positions may be zero-padded when a sequence
        # has fewer than the maximum number of predictions; `label_weights`
        # is 1.0 for real predictions and 0.0 for padding, so padded entries
        # drop out of the weighted sum below.
        per_position_loss = flow.reshape(
            mlm_criterion(flat_logits, flat_labels),
            [-1, max_prediction_per_seq],
        )
        weighted_sum = flow.sum(per_position_loss * label_weights)
        # The small epsilon keeps the division defined when all weights are 0.
        weight_total = flow.sum(label_weights) + 1e-5
        return weighted_sum / weight_total

    # The bound loss function is the same for every step, so build it once.
    masked_lm_loss_fn = partial(
        get_masked_lm_loss,
        max_prediction_per_seq=args.max_predictions_per_seq,
    )

    train_total_losses = []
    for epoch in range(args.epochs):
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_batch_size,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        for step in range(len(train_data_loader)):
            bert_outputs = pretrain(
                train_data_loader,
                bert_model,
                ns_criterion,
                masked_lm_loss_fn,
                optimizer,
                lr_scheduler,
            )

            # Only rank 0 reports metrics.
            if flow.env.get_rank() == 0:
                metric.metric_cb(step, epoch=epoch)(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

        # Eval
        bert_model.eval()
        val_acc = validation(
            epoch,
            test_data_loader,
            bert_model,
            args.val_print_every_n_iters,
        )

        save_model(bert_model, args.checkpoint_path, epoch, val_acc, False)
示例#5
0
def main():
    """Pre-train BERT with OneFlow `nn.Graph`, then validate and save the model.

    Reads the run configuration from `get_config()`, builds the train/test
    OFRecord data loaders, the `BertForPreTraining` model, the optimizer and
    the learning-rate schedule, wraps training and evaluation in two
    `nn.Graph` subclasses, runs `args.epochs` training epochs, and finally
    runs validation once and saves the model.

    Raises:
        ValueError: If `args.epochs` is smaller than 1, or if an explicitly
            configured global batch size is not a multiple of the
            corresponding `*_batch_size`.
    """
    args = get_config()

    # The last value of `epoch` is reused after the training loop for
    # validation and checkpointing, so the loop must run at least once.
    if args.epochs < 1:
        raise ValueError("epochs must be >= 1, got {}".format(args.epochs))

    # Global batch sizes default to (batch size * world size); an explicit
    # value must be a whole multiple of the batch size. Explicit raises are
    # used instead of `assert` so the checks survive `python -O`.
    world_size = flow.env.get_world_size()
    if args.train_global_batch_size is None:
        args.train_global_batch_size = args.train_batch_size * world_size
    elif args.train_global_batch_size % args.train_batch_size != 0:
        raise ValueError(
            "train_global_batch_size ({}) must be a multiple of "
            "train_batch_size ({})".format(
                args.train_global_batch_size, args.train_batch_size
            )
        )

    if args.val_global_batch_size is None:
        args.val_global_batch_size = args.val_batch_size * world_size
    elif args.val_global_batch_size % args.val_batch_size != 0:
        raise ValueError(
            "val_global_batch_size ({}) must be a multiple of "
            "val_batch_size ({})".format(
                args.val_global_batch_size, args.val_batch_size
            )
        )

    flow.boxing.nccl.set_fusion_threshold_mbytes(args.nccl_fusion_threshold_mb)
    flow.boxing.nccl.set_fusion_max_ops_num(args.nccl_fusion_max_ops)

    device = "cuda" if args.with_cuda else "cpu"

    print("Device is: ", device)

    print("Creating Dataloader")
    train_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_global_batch_size,
        data_part_num=args.train_data_part,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    # NOTE: the test loader uses a fixed dataset size (1024) and part count (4)
    # rather than values taken from `args`.
    test_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="test",
        dataset_size=1024,
        batch_size=args.val_global_batch_size,
        data_part_num=4,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    print("Building BERT Model")
    # Model width is derived from the head count: 64 hidden units per head,
    # and a feed-forward size of four times the hidden size.
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    # Load the same initial parameters with lazy model.
    # from utils.compare_lazy_outputs import load_params_from_lazy
    # load_params_from_lazy(
    #     bert_model.state_dict(),
    #     "../../OneFlow-Benchmark/LanguageModeling/BERT/initial_model",
    # )

    # Sanity check: the MLM decoder weight and the word-embedding weight are
    # expected to be the very same object (tied weights).
    assert id(bert_model.cls.predictions.decoder.weight) == id(
        bert_model.bert.embeddings.word_embeddings.weight
    )

    # NSP loss is averaged by the criterion; the MLM loss is kept per element
    # so it can be weighted by `label_weights` in `get_masked_lm_loss`.
    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    if args.use_consistent:
        # Consistent mode: place the model on all CUDA devices with broadcast
        # SBP. The criteria are not moved in this branch.
        placement = flow.env.all_device_placement("cuda")
        bert_model = bert_model.to_consistent(
            placement=placement, sbp=flow.sbp.broadcast
        )
    else:
        bert_model.to(device)
        ns_criterion.to(device)
        mlm_criterion.to(device)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    # Total number of training iterations, and the share of them used for
    # learning-rate warm-up.
    steps = args.epochs * len(train_data_loader)
    warmup_steps = int(steps * args.warmup_proportion)

    # Polynomial decay to 0.0 over all steps, wrapped in a linear warm-up
    # that starts from factor 0.
    lr_scheduler = PolynomialLR(optimizer, steps=steps, end_learning_rate=0.0)

    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        lr_scheduler, warmup_factor=0, warmup_iters=warmup_steps, warmup_method="linear"
    )

    def get_masked_lm_loss(
        logit, masked_lm_labels, label_weights, max_predictions_per_seq,
    ):
        """Return the weighted mean masked-LM loss.

        Args:
            logit: Prediction scores passed as-is to `mlm_criterion`.
            masked_lm_labels: Target ids; flattened to 1-D before the loss.
            label_weights: Weights multiplied element-wise with the
                per-prediction loss after it is reshaped to
                `[-1, max_predictions_per_seq]`.
            max_predictions_per_seq: Width used to reshape the per-prediction
                loss so it lines up with `label_weights`.

        Returns:
            `sum(loss * label_weights) / (sum(label_weights) + 1e-5)`.
        """

        label_id = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        pre_example_loss = mlm_criterion(logit, label_id)
        pre_example_loss = flow.reshape(pre_example_loss, [-1, max_predictions_per_seq])
        numerator = flow.sum(pre_example_loss * label_weights)
        # The small epsilon keeps the division defined when all weights are 0.
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss

    class BertGraph(nn.Graph):
        """Training graph: one call loads a batch, runs forward and backward."""

        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self.ns_criterion = ns_criterion
            self.masked_lm_criterion = partial(
                get_masked_lm_loss, max_predictions_per_seq=args.max_predictions_per_seq
            )
            # Register the optimizer and its LR scheduler with the graph.
            self.add_optimizer(optimizer, lr_sch=lr_scheduler)
            self._train_data_loader = train_data_loader
            if args.grad_acc_steps > 1:
                self.config.set_gradient_accumulation_steps(args.grad_acc_steps)
            if args.use_fp16:
                self.config.enable_amp(True)
                grad_scaler = flow.amp.GradScaler(
                    init_scale=2 ** 30,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=2000,
                )
                self.set_grad_scaler(grad_scaler)
            self.config.allow_fuse_add_to_output(True)
            self.config.allow_fuse_model_update_ops(True)

        def build(self):
            """Run one training step on a batch pulled from the train loader.

            Returns:
                Tuple `(seq_relationship_scores, next_sentence_labels,
                total_loss, masked_lm_loss, next_sentence_loss)`.
            """

            (
                input_ids,
                next_sentence_labels,
                input_mask,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._train_data_loader()
            input_ids = input_ids.to(device=device)
            input_mask = input_mask.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sentence_labels = next_sentence_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device=device)
            masked_lm_weights = masked_lm_weights.to(device=device)

            # 1. forward the next_sentence_prediction and masked_lm model
            prediction_scores, seq_relationship_scores = self.bert(
                input_ids, segment_ids, input_mask, masked_lm_positions
            )

            # 2-1. loss of is_next classification result
            next_sentence_loss = self.ns_criterion(
                seq_relationship_scores.reshape(-1, 2), next_sentence_labels.reshape(-1)
            )

            # 2-2. weighted masked-LM loss
            masked_lm_loss = self.masked_lm_criterion(
                prediction_scores, masked_lm_ids, masked_lm_weights
            )

            total_loss = masked_lm_loss + next_sentence_loss

            total_loss.backward()
            return (
                seq_relationship_scores,
                next_sentence_labels,
                total_loss,
                masked_lm_loss,
                next_sentence_loss,
            )

    bert_graph = BertGraph()

    class BertEvalGraph(nn.Graph):
        """Evaluation graph: one call loads a test batch and runs forward only."""

        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self._test_data_loader = test_data_loader
            self.config.allow_fuse_add_to_output(True)

        def build(self):
            """Run a forward pass on a batch pulled from the test loader.

            Returns:
                Tuple `(seq_relationship_scores, next_sent_labels)`.
            """
            (
                input_ids,
                next_sent_labels,
                input_masks,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._test_data_loader()
            input_ids = input_ids.to(device=device)
            input_masks = input_masks.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sent_labels = next_sent_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device=device)

            with flow.no_grad():
                # 1. forward the next_sentence_prediction and masked_lm model.
                # The arguments follow the same order as the training graph's
                # call: ids, segment ids, mask, masked positions.
                _, seq_relationship_scores = self.bert(
                    input_ids, segment_ids, input_masks, masked_lm_positions
                )

            return seq_relationship_scores, next_sent_labels

    bert_eval_graph = BertEvalGraph()

    train_total_losses = []

    for epoch in range(args.epochs):
        # A fresh metric logger per epoch; the reported batch size accounts
        # for gradient accumulation.
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_global_batch_size * args.grad_acc_steps,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        for step in range(len(train_data_loader)):
            bert_outputs = pretrain(bert_graph, args.metric_local)

            # Only rank 0 reports metrics.
            if flow.env.get_rank() == 0:
                metric.metric_cb(step, epoch=epoch)(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

    # Eval
    # Runs once, after the final epoch; `epoch` holds the last epoch index.
    bert_model.eval()
    val_acc = validation(
        epoch,
        len(test_data_loader),
        bert_eval_graph,
        args.val_print_every_n_iters,
        args.metric_local,
    )

    save_model(bert_model, args.checkpoint_path, epoch, val_acc, args.use_consistent)