예제 #1
0
def setup_logging(args, local_rank):
    """Configure logging for this process and seed the random generators.

    Parameters
    ----------
    args
        Parsed arguments; ``output_dir``, ``version`` and ``seed`` are read.
    local_rank
        Local rank of this process; the ``console`` flag handed to
        ``logging_config`` is true only when it equals 0.
    """
    # The logger name embeds the SQuAD version (original note: "avoid race").
    logger_name = 'finetune_squad{}'.format(args.version)
    echo_to_console = local_rank == 0
    logging_config(args.output_dir,
                   name=logger_name,
                   overwrite_handler=True,
                   console=echo_to_console)
    logging.info(args)

    # Seed first, then report it, so the debug line reflects what was applied.
    seed = args.seed
    set_seed(seed)
    logging.debug('Random seed set to {}'.format(seed))
예제 #2
0
def train(args):
    """Run the BERT pre-training loop (masked-LM + next-sentence objectives).

    The function initialises the communication backend, logging and the random
    seed, builds the model and the training data loader, creates the trainer
    for the selected backend and then iterates until ``args.num_steps`` is
    reached. Gradients are optionally accumulated over
    ``args.num_accumulated`` micro-batches per step. Checkpoints are written
    every ``args.ckpt_interval`` steps and progress is logged every
    ``args.log_interval`` steps.

    Parameters
    ----------
    args
        Parsed command-line namespace. Every option consumed here is read as
        an attribute (e.g. ``comm_backend``, ``gpus``, ``ckpt_dir``, ``seed``,
        ``model_name``, ``start_step``, ``num_steps``, ``lr``, ``wd``).

    Returns
    -------
    None
    """
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_config(
        args.ckpt_dir,
        name='pretrain_bert_' + str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    # Apply the seed before reporting it so the log line is truthful.
    set_seed(args.seed)
    logging.debug('Random seed set to {}'.format(args.seed))
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l)
    if args.start_step:
        # Resuming: restore the parameters saved at ``start_step``.
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args.ckpt_dir, 'Loading',
                          ctx_l)
    else:
        model.initialize(ctx=ctx_l)
    model.hybridize()

    # Choose the dataset builder: raw text needs the on-the-fly
    # pre-processing options, otherwise a pre-built npz loader is used.
    if args.raw:
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            masked_lm_prob=args.masked_lm_prob,
            max_predictions_per_seq=args.max_predictions_per_seq,
            whole_word_mask=args.whole_word_mask,
            random_next_sentence=args.random_next_sentence,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            dataset_cached=args.dataset_cached,
            num_max_dataset_cached=args.num_max_dataset_cached)
    else:
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data,
                                args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)

    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'

    num_steps = args.num_steps
    warmup_steps = int(num_steps * args.warmup_ratio)
    log_interval = args.log_interval
    save_interval = args.ckpt_interval
    logging.info(
        '#Total Training Steps={}, Warmup Steps={}, Save Interval={}'.format(
            num_steps, warmup_steps, save_interval))
    optimizer_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    elif args.comm_backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        # Resuming: restore the trainer states saved at ``start_step``.
        logging.info('Restart training from {}'.format(args.start_step))
        states_option(args.start_step, trainer, args.ckpt_dir, local_rank,
                      'Loading')

    # backend specific implementation
    if args.comm_backend == 'byteps':
        trainer._init_params()
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # prepare the loss function
    nsp_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    nsp_loss_fn.hybridize()
    mlm_loss_fn.hybridize()

    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    step_num = args.start_step
    if args.phase2:
        # Phase 2 counts its steps relative to the end of phase 1.
        step_num -= args.phase1_num_steps
    # Step counter value at the start of this run; used for the ETA estimate.
    initial_step_num = step_num

    running_mlm_loss, running_nsp_loss = 0., 0.
    running_num_tks = 0

    train_start_time = time.time()
    tic = time.time()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_steps:
        step_num += 1
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            mlm_loss_l = []
            nsp_loss_l = []
            loss_l = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for sample, ctx in zip(sample_l, ctx_l):
                # prepare data
                (input_id, masked_id, masked_position, masked_weight,
                 next_sentence_label, segment_id, valid_length) = sample
                input_id = input_id.as_in_ctx(ctx)
                masked_id = masked_id.as_in_ctx(ctx)
                masked_position = masked_position.as_in_ctx(ctx)
                masked_weight = masked_weight.as_in_ctx(ctx)
                next_sentence_label = next_sentence_label.as_in_ctx(ctx)
                segment_id = segment_id.as_in_ctx(ctx)
                valid_length = valid_length.as_in_ctx(ctx)

                with mx.autograd.record():
                    _, _, nsp_score, mlm_scores = model(
                        input_id, segment_id, valid_length, masked_position)
                    # Normalise the MLM loss by the summed mask weights and
                    # by the number of micro-batches and devices per step.
                    denominator = (masked_weight.sum() +
                                   1e-8) * num_accumulated * len(ctx_l)
                    mlm_scores_r = mx.npx.reshape(mlm_scores, (-5, -1))
                    masked_id_r = masked_id.reshape((-1, ))
                    mlm_loss = mlm_loss_fn(mlm_scores_r, masked_id_r,
                                           masked_weight.reshape(
                                               (-1, 1))).sum() / denominator
                    denominator = num_accumulated * len(ctx_l)
                    nsp_loss = nsp_loss_fn(
                        nsp_score, next_sentence_label).mean() / denominator
                    mlm_loss_l.append(mlm_loss)
                    nsp_loss_l.append(nsp_loss)
                    loss_l.append(mlm_loss + nsp_loss)
                    mask_label_list.append(masked_id_r)
                    mask_pred_list.append(mlm_scores_r)
                    mask_weight_list.append(masked_weight.reshape((-1, )))
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(nsp_score)

                running_num_tks += valid_length.sum().as_in_ctx(mx.cpu())

            for loss in loss_l:
                loss.backward()

            running_mlm_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in mlm_loss_l
            ]).asnumpy().item()
            running_nsp_loss += sum([
                ele.as_in_ctx(mx.cpu()) for ele in nsp_loss_l
            ]).asnumpy().item()
            mlm_metric.update(mask_label_list, mask_pred_list,
                              mask_weight_list)
            nsp_metric.update(ns_label_list, ns_pred_list)
        # update
        trainer.allreduce_grads()

        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        total_norm = total_norm / num_workers

        # update learning rate: linear warmup followed by linear decay to 0
        scheduled_lr = args.lr
        if step_num <= warmup_steps:
            scheduled_lr *= step_num / warmup_steps
        else:
            offset = (num_steps - step_num) / (num_steps - warmup_steps)
            scheduled_lr *= max(offset, 0)
        trainer.set_learning_rate(scheduled_lr)

        if args.comm_backend == 'horovod' or args.comm_backend == 'byteps':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers.
            # *num_workers* of Horovod is the number of GPUs.
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1.
            # *num_workers* of Trainer is the number of machines.
            trainer.update(num_workers, ignore_stale_grad=True)

        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        # saving
        if step_num % save_interval == 0 or step_num >= num_steps:
            states_option(step_num, trainer, args.ckpt_dir, local_rank,
                          'Saving')
            if local_rank == 0:
                parameters_option(step_num, model, args.ckpt_dir, 'Saving')
        # logging
        if step_num % log_interval == 0:
            running_mlm_loss /= log_interval
            running_nsp_loss /= log_interval
            toc = time.time()
            # Speed is measured over the steps executed in this run only.
            # Dividing the absolute step number by this run's wall time would
            # overestimate the speed after resuming from a checkpoint.
            steps_per_sec = ((step_num - initial_step_num) /
                             (toc - train_start_time))
            eta_hours = (num_steps - step_num) / steps_per_sec / 3600
            logging.info(
                '[step {}], Loss mlm/nsp={:.5f}/{:.3f}, Acc mlm/nsp={:.3f}/{:.3f}, '
                ' LR={:.7f}, grad_norm={:.4f}. Time cost={:.2f} s,'
                ' Throughput={:.1f}K tks/s, ETA={:.2f} h'.format(
                    step_num, running_mlm_loss, running_nsp_loss,
                    mlm_metric.get()[1],
                    nsp_metric.get()[1], trainer.learning_rate, total_norm,
                    toc - tic,
                    running_num_tks.asnumpy().item() / (toc - tic) / 1000,
                    eta_hours))
            mlm_metric.reset()
            nsp_metric.reset()
            tic = time.time()

            running_mlm_loss = 0
            running_nsp_loss = 0
            running_num_tks = 0

    logging.info('Finish training step: %d', step_num)

    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time -
                                              train_start_time))

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.ckpt_dir, model_name)
        final_save(model, save_dir, tokenizer, cfg)
예제 #3
0
    rtd_preds = mx.np.round((mx.np.sign(rtd_scores) + 1) / 2).astype(np.int32)

    mlm_accuracy = accuracy(unmasked_tokens, mlm_preds, masked_weights)
    corrupted_mlm_accuracy = accuracy(unmasked_tokens, corrupted_tokens,
                                      masked_weights)
    rtd_accuracy = accuracy(rtd_labels, rtd_preds, length_masks)
    rtd_precision = accuracy(rtd_labels, rtd_preds, length_masks * rtd_preds)
    rtd_recall = accuracy(rtd_labels, rtd_preds, rtd_labels * rtd_preds)
    rtd_auc = auc(rtd_labels, rtd_probs, length_masks)
    writer.add_scalars(
        'results', {
            'mlm_accuracy': mlm_accuracy.asnumpy().item(),
            'corrupted_mlm_accuracy': corrupted_mlm_accuracy.asnumpy().item(),
            'rtd_accuracy': rtd_accuracy.asnumpy().item(),
            'rtd_precision': rtd_precision.asnumpy().item(),
            'rtd_recall': rtd_recall.asnumpy().item(),
            'rtd_auc': rtd_auc
        }, step_num)


if __name__ == '__main__':
    # MXNet runtime settings, exported before any training work starts.
    os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
    os.environ['MXNET_USE_FUSION'] = '0'  # Manually disable pointwise fusion
    args = parse_args()
    logging_config(args.output_dir, name='pretrain_owt')
    logging.info(args)
    # Apply the seed before reporting it so the log line is truthful.
    set_seed(args.seed)
    logging.debug('Random seed set to {}'.format(args.seed))
    if args.do_train:
        train(args)
예제 #4
0
def train(args):
    """Run the ELECTRA pre-training loop (MLM generator + RTD discriminator).

    The function initialises the communication backend, logging and the random
    seed, builds the model, the dynamic masker and the training data loader,
    creates the trainer and then iterates until ``args.num_train_steps`` is
    reached. Gradients are optionally accumulated over
    ``args.num_accumulated`` micro-batches per step. Checkpoints are written
    every ``save_interval`` steps and progress is logged every
    ``args.log_interval`` steps. When ``args.do_eval`` is set, ``evaluation``
    is called at each logging step.

    Parameters
    ----------
    args
        Parsed command-line namespace. Every option consumed here is read as
        an attribute (e.g. ``comm_backend``, ``gpus``, ``output_dir``,
        ``seed``, ``model_name``, ``start_step``, ``num_train_steps``).

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If neither ``args.warmup_steps`` nor ``args.warmup_ratio`` is given.
    """
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    logging_config(
        args.output_dir,
        name='pretrain_owt_' + str(rank),  # avoid race
        console=(local_rank == 0))
    logging.info(args)
    # Apply the seed before reporting it so the log line is truthful.
    set_seed(args.seed)
    logging.debug('Random seed set to {}'.format(args.seed))
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l,
                                                  args.max_seq_length,
                                                  args.hidden_dropout_prob,
                                                  args.attention_dropout_prob,
                                                  args.generator_units_scale,
                                                  args.generator_layers_scale)
    data_masker = ElectraMasker(tokenizer,
                                args.max_seq_length,
                                mask_prob=args.mask_prob,
                                replace_prob=args.replace_prob)
    if args.from_raw_text:
        if args.cached_file_path:
            # Every worker process runs this line, so the directory may be
            # created concurrently; ``exist_ok`` makes that harmless and
            # ``makedirs`` also copes with missing parent directories.
            os.makedirs(args.cached_file_path, exist_ok=True)
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            cached_file_path=args.cached_file_path)

        logging.info(
            'Processing and loading the training dataset from raw text.')

    else:
        logging.info('Loading the training dataset from local Numpy file.')
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data,
                                args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)

    logging.info('Creating distributed trainer...')
    param_dict = model.collect_params()
    # Do not apply weight decay to all the LayerNorm and bias
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    num_train_steps = args.num_train_steps
    # An explicit step count wins over the ratio; one of the two is required.
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    elif args.warmup_ratio is not None:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    else:
        raise ValueError('Must specify either warmup_steps or warmup_ratio')
    log_interval = args.log_interval
    if args.save_interval is not None:
        save_interval = args.save_interval
    else:
        # Default to roughly 50 checkpoints per run. Clamp to 1 so that runs
        # shorter than 50 steps do not trigger a modulo-by-zero below.
        save_interval = max(1, num_train_steps // 50)
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))

    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        # TODO(zheyuye), How about data splitting, where to start re-training
        states_option(args.start_step, trainer, args.output_dir, local_rank,
                      'Loading')
        parameters_option(args.start_step, model, args.output_dir, 'Loading')

    # prepare the loss function
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    rtd_loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    mlm_loss_fn.hybridize()
    rtd_loss_fn.hybridize()

    # prepare the records writer
    writer = None
    # only one process on each worker will write the tensorboardX's records to avoid race
    if args.do_eval and local_rank == 0:
        from tensorboardX import SummaryWriter
        record_path = os.path.join(args.output_dir, 'records')
        logging.info('Evaluation records saved in {}'.format(record_path))
        writer = SummaryWriter(record_path)

    # Step counter value at the start of this run; used for the ETA estimate.
    initial_step_num = args.start_step
    step_num = initial_step_num

    log_total_loss = 0
    log_mlm_loss = 0
    log_rtd_loss = 0
    log_sample_num = 0
    train_start_time = time.time()
    # Start of the current logging window. It is reset only after a log line
    # is emitted, so "Time cost" and throughput cover the same steps that
    # ``log_sample_num`` was accumulated over.
    tic = time.time()

    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_train_steps:
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            mlm_loss_l = []
            rtd_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # prepare data
                input_ids, segment_ids, valid_lengths = sample
                input_ids = input_ids.as_in_ctx(ctx)
                segment_ids = segment_ids.as_in_ctx(ctx)
                valid_lengths = valid_lengths.as_in_ctx(ctx)
                masked_input = data_masker.dynamic_masking(
                    mx.nd, input_ids, valid_lengths)
                masked_input_ids = masked_input.input_ids
                length_masks = masked_input.masks
                unmasked_tokens = masked_input.unmasked_tokens
                masked_positions = masked_input.masked_positions
                masked_weights = masked_input.masked_weights

                log_sample_num += len(masked_input_ids)

                with mx.autograd.record():
                    mlm_scores, rtd_scores, corrupted_tokens, labels = model(
                        masked_input_ids, segment_ids, valid_lengths,
                        unmasked_tokens, masked_positions)
                    # Normalise each loss by its summed weights/masks and by
                    # the number of micro-batches and devices per step.
                    denominator = (masked_weights.sum() +
                                   1e-6) * num_accumulated * len(ctx_l)
                    mlm_loss = mlm_loss_fn(
                        mx.npx.reshape(mlm_scores, (-5, -1)),
                        unmasked_tokens.reshape(
                            (-1, )), masked_weights.reshape(
                                (-1, 1))).sum() / denominator
                    denominator = (length_masks.sum() +
                                   1e-6) * num_accumulated * len(ctx_l)
                    rtd_loss = rtd_loss_fn(rtd_scores, labels,
                                           length_masks).sum() / denominator
                    # Kept for the optional evaluation at the logging step.
                    output = ElectraOutput(
                        mlm_scores=mlm_scores,
                        rtd_scores=rtd_scores,
                        rtd_labels=labels,
                        corrupted_tokens=corrupted_tokens,
                    )
                    mlm_loss_l.append(mlm_loss)
                    rtd_loss_l.append(rtd_loss)
                    loss = (args.gen_weight * mlm_loss +
                            args.disc_weight * rtd_loss)
                    loss_l.append(loss)

            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_mlm_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in mlm_loss_l]).asnumpy()
            log_rtd_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in rtd_loss_l]).asnumpy()
            log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
                                   for ele in loss_l]).asnumpy()

        # update
        trainer.allreduce_grads()

        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)

        if args.comm_backend == 'horovod':
            # Note that horovod.trainer._scale is default to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale is default to 1
            trainer.update(num_workers, ignore_stale_grad=True)

        total_norm = total_norm / num_workers
        step_num += 1
        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        # saving
        if step_num % save_interval == 0 or step_num >= num_train_steps:
            if is_master_node:
                states_option(step_num, trainer, args.output_dir, local_rank,
                              'Saving')
                if local_rank == 0:
                    parameters_option(step_num, model, args.output_dir,
                                      'Saving')

        # logging
        if step_num % log_interval == 0:
            # Output the loss of per step
            log_mlm_loss /= log_interval
            log_rtd_loss /= log_interval
            log_total_loss /= log_interval
            toc = time.time()
            # Speed is measured over the steps executed in this run only.
            # Dividing the absolute step number by this run's wall time would
            # overestimate the speed after resuming from a checkpoint.
            steps_per_sec = ((step_num - initial_step_num) /
                             (toc - train_start_time))
            eta_hours = (num_train_steps - step_num) / steps_per_sec / 3600
            logging.info('[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},'
                         ' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},'
                         ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
                             step_num, log_mlm_loss, log_rtd_loss,
                             log_total_loss, trainer.learning_rate, total_norm,
                             toc - tic, log_sample_num / (toc - tic),
                             eta_hours))
            tic = time.time()

            if args.do_eval:
                evaluation(writer, step_num, masked_input, output)
                if writer is not None:
                    writer.add_scalars(
                        'loss', {
                            'total_loss': log_total_loss,
                            'mlm_loss': log_mlm_loss,
                            'rtd_loss': log_rtd_loss
                        }, step_num)
            log_mlm_loss = 0
            log_rtd_loss = 0
            log_total_loss = 0
            log_sample_num = 0

    logging.info('Finish training step: %d', step_num)
    if is_master_node:
        states_option(step_num, trainer, args.output_dir, local_rank,
                      'Saving')
        if local_rank == 0:
            parameters_option(step_num, model, args.output_dir, 'Saving')

    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_start_time))

    if writer is not None:
        writer.close()

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.output_dir, model_name)
        final_save(model, save_dir, tokenizer)