def train(data_train, data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Training function."""
    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers}
    else:
        loss_scale_param = None
    trainer = hvd.DistributedTrainer(model.collect_params(), 'bertadam', optim_params)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    while step_num < num_train_steps:
        for _, dataloader in enumerate(data_train):
            if step_num >= num_train_steps:
                break

            # create dummy data loader if needed
            if args.dummy_data_len:
                target_shape = (args.batch_size, args.dummy_data_len)
                dataloader = get_dummy_dataloader(dataloader, target_shape)

            for _, data_batch in enumerate(dataloader):
                if step_num >= num_train_steps:
                    break
                if batch_num % accumulate == 0:
                    step_num += 1
                    # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                    if accumulate > 1:
                        param_dict.zero_grad()
                    # update learning rate
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 14, profile_name=args.profile + str(rank))

                # load data
                if args.use_avg_len:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip([ctx], data_batch)]
                else:
                    data_list = list(split_and_load(data_batch, [ctx]))
                data = data_list[0]

                # forward
                with mx.autograd.record():
                    (ls, ns_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_len) = forward(data, model, mlm_loss,
                                                                   nsp_loss, vocab_size, args.dtype)
                    ls = ls / accumulate
                    # backward
                    if args.dtype == 'float16':
                        fp16_trainer.backward(ls)
                    else:
                        ls.backward()

                running_mlm_loss += ls1.as_in_context(mx.cpu())
                running_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_len.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    # step() performs 3 things:
                    # 1. allreduce gradients from all workers
                    # 2. check the global norm of gradients and clip them if necessary
                    # 3. average the gradients and apply the updates
                    fp16_trainer.step(1, max_norm=1*num_workers)

                nsp_metric.update([ns_label], [classified])
                mlm_metric.update([masked_id], [decoded], [masked_weight])

                # logging
                if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    begin_time = time.time()
                    running_mlm_loss = running_nsp_loss = running_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                    if is_master_node:
                        save_states(step_num, trainer, args.ckpt_dir, local_rank)
                        if local_rank == 0:
                            save_parameters(step_num, model, args.ckpt_dir)
                    if data_eval:
                        # eval data is always based on a fixed npz file.
                        dataset_eval = get_pretrain_data_npz(data_eval, args.batch_size_eval, 1,
                                                             False, False, 1)
                        evaluate(dataset_eval, model, nsp_loss, mlm_loss, len(vocab), [ctx],
                                 args.log_interval, args.dtype)

                batch_num += 1

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
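
# Illustrative sketch (not part of the original script): the learning-rate
# schedule computed inline in the training loop above, factored into a
# standalone helper. The helper name is hypothetical; the formula mirrors the
# inline computation: linear warmup from 0 to `lr` over `num_warmup_steps`,
# then linear decay towards 0 at `num_train_steps`.
def warmup_linear_decay_lr(lr, step_num, num_warmup_steps, num_train_steps):
    """Learning rate at `step_num` under linear warmup followed by linear decay."""
    if step_num <= num_warmup_steps:
        return lr * step_num / num_warmup_steps
    return lr - lr * step_num / num_train_steps

# e.g. warmup_linear_decay_lr(1e-4, 5000, 10000, 1000000)   -> 5e-05 (mid-warmup)
#      warmup_linear_decay_lr(1e-4, 500000, 10000, 1000000) -> 5e-05 (mid-decay)
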
def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
    """Training function."""
    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    trainer = mx.gluon.Trainer(model.collect_params(),
                               'bertadam',
                               optim_params,
                               update_on_kvstore=False,
                               kvstore=store)
    dynamic_loss_scale = args.dtype == 'float16'
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir,
                                  '%07d.states.%02d' % (args.start_step, 0))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [
        p for p in model.collect_params().values() if p.grad_req != 'null'
    ]
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss = running_nsp_loss = running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    parallel_model = ParallelBERT(model,
                                  mlm_loss,
                                  nsp_loss,
                                  vocab_size,
                                  store.num_workers * accumulate,
                                  trainer=fp16_trainer)
    num_ctxes = len(ctx)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0,
                                  parallel_model)

    while step_num < num_train_steps:
        for _, dataloader in enumerate(data_train):
            if step_num >= num_train_steps:
                break

            # create dummy data loader if needed
            if args.dummy_data_len:
                target_shape = (args.batch_size * num_ctxes,
                                args.dummy_data_len)
                dataloader = get_dummy_dataloader(dataloader, target_shape)

            for _, data_batch in enumerate(dataloader):
                if step_num >= num_train_steps:
                    break
                if batch_num % accumulate == 0:
                    step_num += 1
                    # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                    if accumulate > 1:
                        param_dict.zero_grad()
                    # update learning rate
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 12, profile_name=args.profile)
                if args.use_avg_len:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip(ctx, data_batch)]
                else:
                    if data_batch[0].shape[0] < len(ctx):
                        continue
                    data_list = split_and_load(data_batch, ctx)

                ns_label_list, ns_pred_list = [], []
                mask_label_list, mask_pred_list, mask_weight_list = [], [], []

                # parallel forward / backward
                for data in data_list:
                    parallel.put(data)
                for _ in range(len(ctx)):
                    (_, next_sentence_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_length) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    running_mlm_loss += ls1.as_in_context(mx.cpu()) / num_ctxes
                    running_nsp_loss += ls2.as_in_context(mx.cpu()) / num_ctxes
                    running_num_tks += valid_length.sum().as_in_context(
                        mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    fp16_trainer.step(1, max_norm=1)
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list,
                                  mask_weight_list)
                # logging
                if (step_num + 1) % (args.log_interval) == 0 and (
                        batch_num + 1) % accumulate == 0:
                    log(begin_time, running_num_tks,
                        running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer, args.log_interval)
                    begin_time = time.time()
                    running_mlm_loss = running_nsp_loss = running_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if (step_num + 1) % args.ckpt_interval == 0 \
                   and (batch_num + 1) % accumulate == 0 and store.rank == 0:
                    save_states(step_num, trainer, args.ckpt_dir)
                    save_parameters(step_num, model, args.ckpt_dir)
                batch_num += 1
    if store.rank == 0:
        save_states(step_num, trainer, args.ckpt_dir)
        save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
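
# Illustrative sketch (not part of the original script): the gradient
# accumulation pattern used above, i.e. set grad_req to 'add' so gradients from
# several backward() calls sum up, scale the loss by `accumulate`, and call
# zero_grad()/step() only once per effective batch. The toy network, shapes,
# and helper name are made up for the example.
import mxnet as mx
from mxnet import autograd, gluon

def _toy_grad_accumulation(accumulate=4):
    net = gluon.nn.Dense(1, in_units=8)
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    for p in net.collect_params().values():
        p.grad_req = 'add'                    # gradients accumulate across backward() calls
    net.collect_params().zero_grad()          # clear previously accumulated gradients
    for _ in range(accumulate):
        x = mx.nd.random.uniform(shape=(2, 8))
        with autograd.record():
            loss = net(x).sum() / accumulate  # scale so the sum matches one large batch
        loss.backward()
    trainer.step(1)                           # one optimizer update for `accumulate` micro-batches
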
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.info('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {
            'scale_window': 2000 / num_workers,
            'init_scale': 2**10
        }
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer,
                               dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(
            args.ckpt_dir, '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    if args.phase2:
        step_num -= args.phase1_num_steps

    logging.info('Training started')

    # wrap the model for data-parallel forward/backward across devices
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0,
                                  parallel_model)

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = (num_train_steps - step_num) / (num_train_steps -
                                                             num_warmup_steps)
                    new_lr = lr * max(offset, 0)
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num,
                            10,
                            14,
                            profile_name=args.profile + str(rank))

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            num_data = len(data_list)
            for i in range(num_data):
                parallel.put(data_list[i])
            for _ in range(num_data):
                (next_sentence_label, classified, masked_id, decoded,
                 masked_weight, ls1, ls2, valid_length) = parallel.get()
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)
                running_mlm_loss += ls1.as_in_context(mx.cpu()) / len(ctxs)
                running_nsp_loss += ls2.as_in_context(mx.cpu()) / len(ctxs)
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            # prefetch the next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                fp16_trainer.step(1, max_norm=1.0 * num_workers)
                if accumulate > 1:
                    param_dict.zero_grad()
            # update metrics
            if args.no_compute_acc:
                mask_pred_list[0].wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list,
                                  mask_weight_list)

            # logging
            if step_num % (args.log_interval) == 0 and (batch_num +
                                                        1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks,
                              running_mlm_loss / accumulate,
                              running_nsp_loss / accumulate, step_num, trainer,
                              args.log_interval)
                else:
                    log(begin_time, running_num_tks,
                        running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if step_num % args.ckpt_interval == 0 and (batch_num +
                                                       1) % accumulate == 0:
                if is_master_node:
                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
                    if local_rank == 0:
                        save_parameters(step_num, model.bert, args.ckpt_dir)
            if step_num % args.eval_interval == 0 and data_eval \
                    and (batch_num + 1) % accumulate == 0:
                # eval data is always based on a fixed npz file.
                dataset_eval = get_pretrain_data_npz(data_eval,
                                                     batch_size_eval, 1, False,
                                                     1, vocab)
                evaluate(dataset_eval, model, ctxs, args.log_interval,
                         args.dtype)

            batch_num += 1

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model.bert, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
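
# Illustrative sketch (not the actual FP16Trainer/loss-scaler implementation):
# the dynamic loss-scaling scheme that the `init_scale` and `scale_window`
# options above configure. The usual rule, assumed here, is to halve the scale
# whenever a step overflows and to double it after `scale_window` consecutive
# clean steps.
class _ToyDynamicLossScaler:
    def __init__(self, init_scale=2**10, scale_window=2000):
        self.loss_scale = init_scale
        self.scale_window = scale_window
        self._unskipped = 0

    def update(self, overflow):
        """Adjust the loss scale after each optimizer step."""
        if overflow:
            self.loss_scale = max(self.loss_scale / 2, 1)
            self._unskipped = 0
        else:
            self._unskipped += 1
            if self._unskipped >= self.scale_window:
                self.loss_scale *= 2
                self._unskipped = 0
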
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {
            'scale_window': 2000 / num_workers,
            'init_scale': 1
        }
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer,
                               dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(
            args.ckpt_dir, '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info(
        'Generating the first batch of data, which may take a few minutes ...')

    # wrap the model for data-parallel forward/backward across devices
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0,
                                  parallel_model)

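    # For BytePS, push-pull a dummy tensor and run one batch of dummy data
    # through the parallel model so that all parameters are created and
    # registered with the communication backend before the first real step.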
    if backend == 'byteps':
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks,
                             is_average=False,
                             name="local_num_masks",
                             priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(
            iter(
                get_dummy_dataloader(batch_size, args.max_seq_length,
                                     args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num,
                            10,
                            14,
                            profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                for i in range(num_data):
                    parallel.put(data_list[i])
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_length,
                     num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    running_num_tks += valid_length.sum()
            # prefetch the next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
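                # Normalize the summed MLM loss by the local masked-token count
                # for logging, then allreduce the count so the optimizer step is
                # scaled by the total number of masked tokens across workers.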
                running_mlm_loss += local_mlm_loss / local_num_masks
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks,
                                   average=False,
                                   name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks,
                                         is_average=False,
                                         name="local_num_masks",
                                         priority=0)
                # byteps implicitly divides the scale by num_workers, so compensate here
                fp16_trainer.step(local_num_masks * num_workers,
                                  max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0
            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list,
                                  mask_weight_list)

            # logging
            if (step_num + 1) % (args.log_interval) == 0 and (
                    batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss, 0,
                              step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks,
                        running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if (step_num + 1) % args.ckpt_interval == 0 and (
                    batch_num + 1) % accumulate == 0:
                if is_master_node:
                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
                    if local_rank == 0:
                        save_parameters(step_num, model.bert, args.ckpt_dir)
                if (step_num + 1) % args.eval_interval == 0 and data_eval:
                    # eval data is always based on a fixed npz file.
                    dataset_eval = get_pretrain_data_npz(
                        data_eval, batch_size_eval, 1, False, 1, vocab)
                    evaluate(dataset_eval, model, ctxs, args.log_interval,
                             args.dtype, rank, num_workers)

            batch_num += 1

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model.bert, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
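
# Illustrative sketch (not part of the original script): collect_params()
# accepts a regex, which the training functions above use to exclude LayerNorm
# parameters (beta/gamma) and biases from weight decay by setting wd_mult to 0.
# The toy block and helper name below are made up for the example.
from mxnet import gluon

def _toy_weight_decay_exclusion():
    toy = gluon.nn.HybridSequential()
    toy.add(gluon.nn.Dense(4), gluon.nn.LayerNorm())
    for _, param in toy.collect_params('.*beta|.*gamma|.*bias').items():
        param.wd_mult = 0.0  # the optimizer applies no weight decay to these parameters
    return sorted(toy.collect_params('.*beta|.*gamma|.*bias').keys())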