def train(data_train, data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Training function."""
    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
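    # The FP16Trainer wraps the optimizer with a dynamic loss scaler: the scale
    # grows after `scale_window` overflow-free updates and shrinks on overflow.
    # A smaller window (divided by num_workers) makes the scaler adapt faster
    # when many workers train in parallel.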
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers}
    else:
        loss_scale_param = None
    trainer = hvd.DistributedTrainer(model.collect_params(), 'bertadam', optim_params)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
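    # With gradient accumulation, gradients are summed in place over `accumulate`
    # consecutive batches (grad_req='add') and the optimizer steps once per
    # group; the gradients must then be zeroed manually after each update.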
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    while step_num < num_train_steps:
        for dataloader in data_train:
            if step_num >= num_train_steps:
                break

            # create dummy data loader if needed
            if args.dummy_data_len:
                target_shape = (args.batch_size, args.dummy_data_len)
                dataloader = get_dummy_dataloader(dataloader, target_shape)

            for data_batch in dataloader:
                if step_num >= num_train_steps:
                    break
                if batch_num % accumulate == 0:
                    step_num += 1
                    # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                    if accumulate > 1:
                        param_dict.zero_grad()
                    # update learning rate
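                    # Linear warmup followed by linear decay, e.g. with lr=1e-4,
                    # num_train_steps=100000 and warmup_ratio=0.01 (1000 warmup
                    # steps): step 500 uses 5e-5, step 50000 decays back to 5e-5.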
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 14, profile_name=args.profile + str(rank))

                # load data
                if args.use_avg_len:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip([ctx], data_batch)]
                else:
                    data_list = list(split_and_load(data_batch, [ctx]))
                data = data_list[0]

                # forward
                with mx.autograd.record():
                    (ls, ns_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_len) = forward(data, model, mlm_loss,
                                                                   nsp_loss, vocab_size, args.dtype)
                    ls = ls / accumulate
                    # backward
                    if args.dtype == 'float16':
                        fp16_trainer.backward(ls)
                    else:
                        ls.backward()

                running_mlm_loss += ls1.as_in_context(mx.cpu())
                running_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_len.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    # step() performs 3 things:
                    # 1. allreduce gradients from all workers
                    # 2. check the global norm of the gradients and clip them if necessary
                    # 3. average the gradients and apply the update
                    fp16_trainer.step(1, max_norm=1.0 * num_workers)

                nsp_metric.update([ns_label], [classified])
                mlm_metric.update([masked_id], [decoded], [masked_weight])

                # logging
                if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    begin_time = time.time()
                    running_mlm_loss = running_nsp_loss = running_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                    if is_master_node:
                        save_states(step_num, trainer, args.ckpt_dir, local_rank)
                        if local_rank == 0:
                            save_parameters(step_num, model, args.ckpt_dir)
                    if data_eval:
                        # eval data is always based on a fixed npz file.
                        dataset_eval = get_pretrain_data_npz(data_eval, args.batch_size_eval, 1,
                                                             False, False, 1)
                        evaluate(dataset_eval, model, nsp_loss, mlm_loss, len(vocab), [ctx],
                                 args.log_interval, args.dtype)

                batch_num += 1

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
    logging.debug('Random seed set to %d', random_seed)
    mx.random.seed(random_seed)

    if args.data:
        if args.raw:
            get_dataset_fn = functools.partial(get_pretrain_data_text,
                                               max_seq_length=args.max_seq_length,
                                               short_seq_prob=args.short_seq_prob,
                                               masked_lm_prob=args.masked_lm_prob,
                                               max_predictions_per_seq=args.max_predictions_per_seq,
                                               whole_word_mask=args.whole_word_mask,
                                               vocab=vocab, tokenizer=tokenizer,
                                               num_workers=args.num_data_workers)
        else:
            get_dataset_fn = get_pretrain_data_npz

        num_parts = 1 if args.dummy_data_len else num_workers
        part_idx = 0 if args.dummy_data_len else rank
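        # Each worker loads a disjoint 1/num_workers shard of the training data
        # via num_parts/part_idx; dummy-data runs share a single part and skip
        # prefetching.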
        data_train = get_dataset_fn(args.data, args.batch_size, 1, True,
                                    args.use_avg_len, args.num_buckets,
                                    num_parts=num_parts, part_idx=part_idx,
                                    prefetch=not args.dummy_data_len)
        train(data_train, data_eval, model, nsp_loss, mlm_loss, len(vocab), ctx)
    if data_eval:
        # eval data is always based on a fixed npz file.
        dataset_eval = get_pretrain_data_npz(data_eval, args.batch_size_eval, 1,
                                             False, False, 1)
        evaluate(dataset_eval, model, nsp_loss, mlm_loss, len(vocab), [ctx],
                 args.log_interval, args.dtype)
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.info('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {
            'scale_window': 2000 / num_workers,
            'init_scale': 2**10
        }
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer,
                               dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(
            args.ckpt_dir, '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    if args.phase2:
        step_num -= args.phase1_num_steps
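        # phase 2 resumes from a phase-1 checkpoint, so the phase-1 steps are
        # removed from the counter to restart the warmup/decay schedule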

    logging.info('Training started')

    # wrap the model for multi-GPU data-parallel forward/backward
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0,
                                  parallel_model)
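    # nlp.utils.Parallel dispatches each data shard to a worker thread that runs
    # the forward/backward pass on its own context; with a single context it
    # falls back to serial execution in the main thread.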

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # update learning rate
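                # Linear warmup to `lr`, then linear decay to zero at
                # num_train_steps, e.g. lr=1e-4, 10000 steps, 100 warmup steps:
                # step 50 gives 5e-5 and step 5050 gives 5e-5 on the way down.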
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = (num_train_steps - step_num) / (num_train_steps -
                                                             num_warmup_steps)
                    new_lr = lr * max(offset, 0)
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num,
                            10,
                            14,
                            profile_name=args.profile + str(rank))

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            num_data = len(data_list)
            for i in range(num_data):
                parallel.put(data_list[i])
            for _ in range(num_data):
                (next_sentence_label, classified, masked_id, decoded,
                 masked_weight, ls1, ls2, valid_length) = parallel.get()
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)
                running_mlm_loss += ls1.as_in_context(mx.cpu()) / len(ctxs)
                running_nsp_loss += ls2.as_in_context(mx.cpu()) / len(ctxs)
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            # pre fetch next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                fp16_trainer.step(1, max_norm=1.0 * num_workers)
                if accumulate > 1:
                    param_dict.zero_grad()
            # update metrics
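            # --no_compute_acc skips the accuracy bookkeeping (which pulls
            # predictions back to the CPU); blocking on one prediction still
            # keeps the asynchronous engine queue from growing unbounded.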
            if args.no_compute_acc:
                mask_pred_list[0].wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list,
                                  mask_weight_list)

            # logging
            if step_num % args.log_interval == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks,
                              running_mlm_loss / accumulate,
                              running_nsp_loss / accumulate, step_num, trainer,
                              args.log_interval)
                else:
                    log(begin_time, running_num_tks,
                        running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if step_num % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                if is_master_node:
                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
                    if local_rank == 0:
                        save_parameters(step_num, model.bert, args.ckpt_dir)
            if step_num % args.eval_interval == 0 and data_eval \
                    and (batch_num + 1) % accumulate == 0:
                # eval data is always based on a fixed npz file.
                dataset_eval = get_pretrain_data_npz(data_eval,
                                                     batch_size_eval, 1, False,
                                                     1, vocab)
                evaluate(dataset_eval, model, ctxs, args.log_interval,
                         args.dtype)

            batch_num += 1

    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model.bert, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
    if args.data:
        if args.raw:
            get_dataset_fn = functools.partial(get_pretrain_data_text,
                                               max_seq_length=args.max_seq_length,
                                               short_seq_prob=args.short_seq_prob,
                                               masked_lm_prob=args.masked_lm_prob,
                                               max_predictions_per_seq=args.max_predictions_per_seq,
                                               whole_word_mask=args.whole_word_mask,
                                               tokenizer=tokenizer,
                                               num_max_dataset_cached=args.num_max_dataset_cached)
        else:
            get_dataset_fn = get_pretrain_data_npz

        if args.synthetic_data:
            data_train = get_dummy_dataloader(batch_size, args.max_seq_length,
                                              args.max_predictions_per_seq)
        else:
            shuffle = True
            logging.info(
                'args.num_buckets: {}, num_workers: {}, rank: {}'.format(
                    args.num_buckets, num_workers, rank))
            data_train = get_dataset_fn(
                args.data,
                batch_size,
                len(ctxs),
                shuffle,
                args.num_buckets,
                vocab,
                num_parts=num_workers,
                part_idx=rank,
                num_dataset_workers=args.num_dataset_workers,
                num_batch_workers=args.num_batch_workers)
        train(data_train, data_eval, model)
    if data_eval:
        # eval data is always based on a fixed npz file.
        shuffle = False
        dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                             len(ctxs), shuffle, 1, vocab)
        evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype)
def train(data_train, data_eval, model):
    """Training function."""
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
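    # Enable bias correction of the first/second moment estimates for LAMB,
    # as used in large-batch BERT pre-training.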
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d'%(args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # wrap the model for multi-GPU data-parallel forward/backward
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    if backend == 'byteps':
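        # BytePS requires tensors to be declared and pushed once before use, so
        # declare the shared mask-count tensor and run one dummy batch to force
        # parameter initialization before _init_params() registers them.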
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False, name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length,
                                                    args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            with mx.autograd.record():
                num_data = len(data_list)
                for i in range(num_data):
                    parallel.put(data_list[i])
                for _ in range(num_data):
                    (next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length, num_masks) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_num_masks += num_masks
                    local_mlm_loss += ls1
                    running_num_tks += valid_length.sum()
            # pre fetch next batch
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
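            # Gradients are summed over masked tokens; after allreducing the mask
            # count across workers, step() rescales by the global count so the
            # loss is averaged per masked position. The extra num_workers factor
            # cancels the backend's implicit division by the number of workers.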
            if (batch_num + 1) % accumulate == 0:
                running_mlm_loss += local_mlm_loss / local_num_masks
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False, name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because byteps implicitly sets scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers, max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0
            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging
            if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
#                if is_master_node:
#                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
#                    if local_rank == 0:
#                        save_parameters(step_num, model.bert, args.ckpt_dir)
                if (step_num + 1) % args.eval_interval == 0 and data_eval:
                    # eval data is always based on a fixed npz file.
                    dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                         1, False, 1, vocab)
                    evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype, rank, num_workers)

            batch_num += 1

#    if is_master_node:
#        save_states(step_num, trainer, args.ckpt_dir, local_rank)
#        if local_rank == 0:
#            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
    if args.data:
        if args.raw:
            get_dataset_fn = functools.partial(get_pretrain_data_text,
                                               max_seq_length=args.max_seq_length,
                                               short_seq_prob=args.short_seq_prob,
                                               masked_lm_prob=args.masked_lm_prob,
                                               max_predictions_per_seq=args.max_predictions_per_seq,
                                               whole_word_mask=args.whole_word_mask,
                                               tokenizer=tokenizer)
        else:
            get_dataset_fn = get_pretrain_data_npz

        if args.synthetic_data:
            data_train = get_dummy_dataloader(batch_size, args.max_seq_length,
                                              args.max_predictions_per_seq)
        else:
            shuffle = True
            data_train = get_dataset_fn(args.data, batch_size,
                                        len(ctxs), shuffle, args.num_buckets, vocab,
                                        num_parts=num_workers, part_idx=rank,
                                        num_workers=args.num_data_workers)
        train(data_train, data_eval, model)
    if data_eval:
        # eval data is always based on a fixed npz file.
        shuffle = False
        dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                             len(ctxs), shuffle, 1, vocab)

        evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype, local_rank, 8)
    while True:
        time.sleep(999999999)