def get_lr_scheduler(args, train_loader):
    if args.optim_phase == 'Factor':
        every_lr_decay_step = args.every_lr_decay_step
        lr_scheduler = FactorScheduler(step=every_lr_decay_step, factor=0.1)
    elif args.optim_phase == 'MultiFactor':
        lr_decay_steps = [
            len(train_loader) * ep for ep in args.lr_decay_epochs
        ]
        lr_scheduler = MultiFactorScheduler(step=lr_decay_steps, factor=0.1)
    elif args.optim_phase == 'Poly':
        max_update_step = args.epochs
        lr_scheduler = PolyScheduler(max_update=max_update_step)
    elif args.optim_phase == 'Cosine':
        max_update_step = args.epochs
        lr_scheduler = CosineScheduler(max_update=max_update_step)
    else:
        raise ValueError('Invalid phase {}'.format(args.optim_phase))
    return lr_scheduler
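
# --- Usage sketch (not part of the original snippet) ---
# A minimal wiring example, assuming the scheduler classes used above come from
# mxnet.lr_scheduler and that `args` is an argparse-style namespace. All names and
# values below (_sketch_*) are illustrative placeholders.
from types import SimpleNamespace
import mxnet as mx

_sketch_loader = [None] * 100                      # stand-in: only len() is used
_sketch_args = SimpleNamespace(optim_phase='MultiFactor',
                               lr_decay_epochs=[30, 60],   # decay after epochs 30 and 60
                               epochs=90,
                               every_lr_decay_step=1000)
_sketch_scheduler = get_lr_scheduler(_sketch_args, _sketch_loader)
_sketch_net = mx.gluon.nn.Dense(2)
_sketch_net.initialize()
_sketch_trainer = mx.gluon.Trainer(_sketch_net.collect_params(), 'sgd',
                                   {'learning_rate': 0.1,
                                    'lr_scheduler': _sketch_scheduler})
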
Example #2
def get_optimizer(cfg, updates_per_epoch):
    max_update = int(updates_per_epoch * cfg.num_train_epochs)
    warmup_steps = int(updates_per_epoch * cfg.num_train_epochs *
                       cfg.warmup_portion)
    if cfg.lr_scheduler == 'triangular':
        assert warmup_steps < max_update
        lr_scheduler = PolyScheduler(max_update=max_update,
                                     base_lr=cfg.lr,
                                     warmup_begin_lr=cfg.begin_lr,
                                     pwr=1,
                                     final_lr=cfg.final_lr,
                                     warmup_steps=warmup_steps,
                                     warmup_mode='linear')
    elif cfg.lr_scheduler == 'inv_sqrt':
        lr_scheduler = InverseSquareRootScheduler(warmup_steps=warmup_steps,
                                                  base_lr=cfg.lr,
                                                  warmup_init_lr=cfg.begin_lr)
    elif cfg.lr_scheduler == 'constant':
        lr_scheduler = None
    elif cfg.lr_scheduler == 'cosine':
        assert warmup_steps < max_update
        lr_scheduler = CosineScheduler(max_update=max_update,
                                       base_lr=cfg.lr,
                                       final_lr=cfg.final_lr,
                                       warmup_steps=warmup_steps,
                                       warmup_begin_lr=cfg.begin_lr)
    else:
        raise ValueError('Unsupported lr_scheduler="{}"'.format(
            cfg.lr_scheduler))
    optimizer_params = {
        'learning_rate': cfg.lr,
        'wd': cfg.wd,
        'lr_scheduler': lr_scheduler
    }
    optimizer = cfg.optimizer
    additional_params = {key: value for key, value in cfg.optimizer_params}
    optimizer_params.update(additional_params)
    return optimizer, optimizer_params, max_update
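
# --- Usage sketch (not part of the original snippet) ---
# Shows how the returned (optimizer, optimizer_params, max_update) triple might feed
# a Gluon Trainer. The cfg fields and tiny network are hypothetical placeholders;
# 'adam' stands in for the 'adamw' optimizer that the original configs usually name.
from types import SimpleNamespace
import mxnet as mx

_sketch_cfg = SimpleNamespace(lr_scheduler='cosine', lr=2e-5, begin_lr=0.0,
                              final_lr=0.0, wd=0.01, warmup_portion=0.1,
                              num_train_epochs=3, optimizer='adam',
                              optimizer_params=[])
_opt_name, _opt_params, _max_update = get_optimizer(_sketch_cfg, updates_per_epoch=1000)
_sketch_net = mx.gluon.nn.Dense(2)
_sketch_net.initialize()
_sketch_trainer = mx.gluon.Trainer(_sketch_net.collect_params(), _opt_name, _opt_params)
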
Example #3
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    setup_logging(args, local_rank)
    cfg, tokenizer, qa_net, use_segmentation = \
        get_network(args.model_name, ctx_l,
                    args.classifier_dropout,
                    args.param_checkpoint,
                    args.backbone_path)

    logging.info('Prepare training data')
    train_features = get_squad_features(args, tokenizer, segment='train')
    dataset_processor = SquadDatasetProcessor(
        tokenizer=tokenizer,
        doc_stride=args.doc_stride,
        max_seq_length=args.max_seq_length,
        max_query_length=args.max_query_length)
    logging.info('Processing the Training data:')
    train_dataset, num_answer_mismatch, num_unreliable \
        = dataset_processor.get_train(train_features, skip_unreliable=True)
    logging.info(
        'Done! #Unreliable Span={} / #Mismatched Answer={} / #Total={}'.format(
            num_unreliable, num_answer_mismatch, len(train_features)))

    # Get dataset statistics
    num_impossible = 0
    for sample in train_dataset:
        num_impossible += sample.is_impossible
    logging.info('Before Chunking, #Train/Is Impossible = {}/{}'.format(
        len(train_features),
        sum([ele.is_impossible for ele in train_features])))
    logging.info('After Chunking, #Train Sample/Is Impossible = {}/{}'.format(
        len(train_dataset), num_impossible))

    # Shuffle the dataset using a fixed seed across all workers
    rs = np.random.RandomState(args.pre_shuffle_seed)
    rs.shuffle(train_dataset)
    sampler = SplitSampler(len(train_dataset),
                           num_parts=num_workers,
                           part_index=rank,
                           even_size=True)
    train_dataloader = mx.gluon.data.DataLoader(
        train_dataset,
        batchify_fn=dataset_processor.BatchifyFunction,
        batch_size=args.batch_size,
        num_workers=0,
        sampler=sampler)
    if 'electra' in args.model_name:
        # Freeze parameters; this does not work for the ALBERT model since parameters in all layers are shared
        if args.untunable_depth > 0:
            qa_net.backbone.frozen_params(args.untunable_depth)
        if args.layerwise_decay > 0:
            qa_net.backbone.apply_layerwise_decay(args.layerwise_decay)

    logging.info('Creating distributed trainer...')
    # Collect differentiable parameters
    param_dict = qa_net.collect_params()
    # Do not apply weight decay to the LayerNorm and bias parameters
    for _, v in qa_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    epoch_size = (len(train_dataloader) + len(ctx_l) - 1) // len(ctx_l)
    if args.num_train_steps is not None:
        num_train_steps = args.num_train_steps
    else:
        num_train_steps = int(args.epochs * epoch_size / args.num_accumulated)
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, 'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None\
        else epoch_size // args.num_accumulated
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))

    # set up optimization
    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    adam_betas = eval(args.adam_betas)
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
            'correct_bias': False,
        })
    elif args.optimizer == 'adam':
        optimizer_params.update({
            'beta1': adam_betas[0],
            'beta2': adam_betas[1],
            'epsilon': args.adam_epsilon,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)

    log_span_loss = 0
    log_answerable_loss = 0
    log_total_loss = 0
    log_sample_num = 0

    global_tic = time.time()
    tic = time.time()
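    # `repeat` cycles the dataloader endlessly and `grouper` chunks the stream into
    # len(ctx_l) * num_accumulated batches per optimization step (padding incomplete
    # groups with None, hence the `sample is None` check below), so the loop is
    # driven by step count rather than by epochs.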
    for step_num, batch_data in enumerate(
            grouper(repeat(train_dataloader),
                    len(ctx_l) * num_accumulated)):
        for sample_l in grouper(batch_data, len(ctx_l)):
            loss_l = []
            span_loss_l = []
            answerable_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # Copy the data to device
                tokens = sample.data.as_in_ctx(ctx)
                log_sample_num += len(tokens)
                segment_ids = sample.segment_ids.as_in_ctx(
                    ctx) if use_segmentation else None
                valid_length = sample.valid_length.as_in_ctx(ctx)
                p_mask = sample.masks.as_in_ctx(ctx)
                gt_start = sample.gt_start.as_in_ctx(ctx).astype(np.int32)
                gt_end = sample.gt_end.as_in_ctx(ctx).astype(np.int32)
                is_impossible = sample.is_impossible.as_in_ctx(ctx).astype(
                    np.int32)
                batch_idx = mx.np.arange(tokens.shape[0],
                                         dtype=np.int32,
                                         ctx=ctx)
                p_mask = 1 - p_mask  # In the network, we use 1 --> no_mask, 0 --> mask
                with mx.autograd.record():
                    start_logits, end_logits, answerable_logits \
                        = qa_net(tokens, segment_ids, valid_length, p_mask, gt_start)
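                    # For every sample in the batch, pick out the score at its gold
                    # start/end position and at its answerability label; negating and
                    # averaging these selections yields the losses below.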
                    sel_start_logits = start_logits[batch_idx, gt_start]
                    sel_end_logits = end_logits[batch_idx, gt_end]
                    sel_answerable_logits = answerable_logits[batch_idx,
                                                              is_impossible]
                    span_loss = -0.5 * (sel_start_logits +
                                        sel_end_logits).mean()
                    answerable_loss = -0.5 * sel_answerable_logits.mean()
                    loss = span_loss + answerable_loss
                    loss_l.append(loss)
                    span_loss_l.append(span_loss)
                    answerable_loss_l.append(answerable_loss)

            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_span_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in span_loss_l]).asnumpy()
            log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
                                   for ele in loss_l]).asnumpy()
            log_answerable_loss += sum([
                ele.as_in_ctx(ctx_l[0]) for ele in answerable_loss_l
            ]).asnumpy()
        # update
        trainer.allreduce_grads()

        if args.max_grad_norm > 0:
            total_norm, ratio, is_finite = clip_grad_global_norm(
                params, args.max_grad_norm * num_workers)
        else:
            total_norm = grad_global_norm(params)

        if args.comm_backend == 'horovod':
            # Note that horovod.trainer._scale defaults to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale defaults to 1
            trainer.update(num_workers, ignore_stale_grad=True)

        total_norm = total_norm / num_workers
        if args.num_accumulated > 1:
            # set grad to zero for gradient accumulation
            qa_net.zero_grad()

        # saving
        if (local_rank == 0 and (step_num + 1) % save_interval == 0) \
                or (step_num + 1) >= num_train_steps:
            version_prefix = 'squad' + args.version
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 version_prefix,
                                                 (step_num + 1))
            params_saved = os.path.join(args.output_dir, ckpt_name)
            qa_net.save_parameters(params_saved)
            ckpt_candidates = [
                f for f in os.listdir(args.output_dir) if f.endswith('.params')
            ]
            # keep last `max_saved_ckpt` checkpoints
            if len(ckpt_candidates) > args.max_saved_ckpt:
                ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
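                # Sorting by (name length, name) orders the files by their numeric
                # step suffix, so the oldest checkpoint is removed first.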
                os.remove(os.path.join(args.output_dir, ckpt_candidates[0]))
            logging.info('Params saved in: {}'.format(params_saved))

        # logging
        if (step_num + 1) % log_interval == 0:
            log_span_loss /= log_sample_num
            log_answerable_loss /= log_sample_num
            log_total_loss /= log_sample_num
            toc = time.time()
            logging.info(
                'Step: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
                ' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
                ' ETA={:.2f}h'.format(
                    (step_num + 1), num_train_steps, log_span_loss,
                    log_answerable_loss, log_total_loss, trainer.learning_rate,
                    total_norm, toc - tic, log_sample_num / (toc - tic),
                    (num_train_steps - (step_num + 1)) /
                    ((step_num + 1) / (toc - global_tic)) / 3600))
            tic = time.time()
            log_span_loss = 0
            log_answerable_loss = 0
            log_total_loss = 0
            log_sample_num = 0
            num_samples_per_update = 0

        if (step_num + 1) >= num_train_steps:
            toc = time.time()
            logging.info('Finish training step: {} within {} hours'.format(
                step_num + 1, (toc - global_tic) / 3600))
            break

    return params_saved
Example #4
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    task = get_task(args.task_name)
    #setup_logging(args, local_rank)
    level = logging.INFO
    detail_dir = os.path.join(args.output_dir, args.task_name)
    if not os.path.exists(detail_dir):
        os.mkdir(detail_dir)
    logging_config(
        detail_dir,
        name='train_{}_{}_'.format(args.task_name, args.model_name) +
        str(rank),  # avoid race
        level=level,
        console=(local_rank == 0))
    logging.info(args)
    cfg, tokenizer, classify_net, use_segmentation = \
        get_network(args.model_name, ctx_l,
                    args.param_checkpoint,
                    args.backbone_path,
                    task)
    logging.info('Prepare training data')
    train_data, _ = get_task_data(args, tokenizer, segment='train', task=task)
    train_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
                              bf.Stack())

    epoch_num_updates = len(train_data) // args.batch_size
    max_update = epoch_num_updates * args.epochs
    warmup_steps = int(np.ceil(max_update * args.warmup_ratio))

    dataloader = DataLoader(train_data,
                            batch_size=args.batch_size,
                            batchify_fn=train_batchify,
                            num_workers=4,
                            shuffle=True)
    dataloader = grouper(repeat(dataloader), len(ctx_l))
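    # Cycle the dataloader forever and regroup it into one batch per device, so the
    # loop below can run for a fixed number of updates (max_update) instead of epochs.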

    param_dict = classify_net.collect_params()
    # Do not apply weight decay to the LayerNorm and bias parameters
    for _, v in classify_net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'

    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    lr_scheduler = PolyScheduler(max_update=max_update,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0.0,
                                 pwr=1,
                                 final_lr=0.0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler
    }
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(classify_net.collect_params(), 'adamw',
                                   optimizer_params,
                                   update_on_kvstore=False)

    if args.task_name == 'sts':
        loss_function = gluon.loss.L2Loss()
    else:
        loss_function = gluon.loss.SoftmaxCELoss()

    # prepare logging state
    log_loss = 0
    log_gnorm = 0
    log_step = 0
    if args.log_interval > 0:
        log_interval = args.log_interval
    else:
        log_interval = int(epoch_num_updates * 0.5)

    for i in range(max_update):
        sample_l = next(dataloader)
        loss_l = []
        for sample, ctx in zip(sample_l, ctx_l):
            (token_ids, token_types, valid_length), label = sample
            # Move to the corresponding context
            token_ids = mx.np.array(token_ids, ctx=ctx)
            token_types = mx.np.array(token_types, ctx=ctx)
            valid_length = mx.np.array(valid_length, ctx=ctx)
            label = mx.np.array(label, ctx=ctx)
            with mx.autograd.record():
                scores = classify_net(token_ids, token_types, valid_length)
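                # Dividing by len(ctx_l) keeps the summed per-device gradients
                # equivalent to the gradient of the average loss across devices.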
                loss = loss_function(scores, label).mean() / len(ctx_l)
                loss_l.append(loss)
        for loss in loss_l:
            loss.backward()
        trainer.allreduce_grads()
        # Begin Norm Clipping
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm)
        trainer.update(1.0)
        step_loss = sum([loss.asnumpy() for loss in loss_l])
        log_loss += step_loss
        log_gnorm += total_norm
        log_step += 1
        if log_step >= log_interval or i == max_update - 1:
            logging.info(
                '[Iter {} / {}] avg {} = {:.2f}, avg gradient norm = {:.2f}'.
                format(i + 1, max_update, 'nll', log_loss / log_step,
                       log_gnorm / log_step))
            log_loss = 0
            log_gnorm = 0
            log_step = 0
        if local_rank == 0 and (i == max_update - 1 or i %
                                (max_update // args.epochs) == 0 and i > 0):
            ckpt_name = '{}_{}_{}.params'.format(args.model_name,
                                                 args.task_name, (i + 1))

            params_saved = os.path.join(detail_dir, ckpt_name)
            classify_net.save_parameters(params_saved)
            logging.info('Params saved in: {}'.format(params_saved))
Example #5
def train(args):
    _, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_config(args.ckpt_dir,
                   name='pretrain_bert_' + str(rank),  # avoid race
                   level=level,
                   console=(local_rank == 0))
    logging.info(args)
    logging.debug('Random seed set to {}'.format(args.seed))
    set_seed(args.seed)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(
                     args.num_buckets, num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l, args.max_seq_length)

    if args.raw:
        get_dataset_fn = functools.partial(get_pretrain_data_text,
                                           max_seq_length=args.max_seq_length,
                                           short_seq_prob=args.short_seq_prob,
                                           masked_lm_prob=args.masked_lm_prob,
                                           max_predictions_per_seq=args.max_predictions_per_seq,
                                           whole_word_mask=args.whole_word_mask,
                                           random_next_sentence=args.random_next_sentence,
                                           tokenizer=tokenizer,
                                           circle_length=args.circle_length,
                                           repeat=args.repeat,
                                           dataset_cached=args.dataset_cached,
                                           num_max_dataset_cached=args.num_max_dataset_cached)
    else:
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data, args.batch_size, shuffle=True,
                                num_buckets=args.num_buckets, vocab=tokenizer.vocab,
                                num_parts=num_workers, part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)

    param_dict = model.collect_params()
    # Do not apply weight decay to the LayerNorm and bias parameters
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Set grad_req if gradient accumulation is required
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    num_accumulated = args.num_accumulated
    if num_accumulated > 1:
        logging.info('Using gradient accumulation. Effective global batch size = {}'
                     .format(num_accumulated * args.batch_size * len(ctx_l) * num_workers))
        for p in params:
            p.grad_req = 'add'

    num_steps = args.num_steps
    warmup_steps = int(num_steps * args.warmup_ratio)
    log_interval = args.log_interval
    save_interval = args.ckpt_interval
    logging.info('#Total Training Steps={}, Warmup Steps={}, Save Interval={}'
                 .format(num_steps, warmup_steps, save_interval))
    lr_scheduler = PolyScheduler(max_update=num_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {'learning_rate': args.lr,
                        'wd': args.wd,
                        'lr_scheduler': lr_scheduler,
                        }
    if args.optimizer == 'adamw':
        optimizer_params.update({'beta1': 0.9,
                                 'beta2': 0.999,
                                 'epsilon': 1e-6,
                                 'correct_bias': False,
                                 })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
    elif args.comm_backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args.ckpt_dir, 'Loading')
        states_option(args.start_step, trainer, args.ckpt_dir, local_rank, 'Loading')

    if args.comm_backend == 'byteps':
        trainer._init_params()
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    # prepare the loss function
    nsp_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    nsp_loss_fn.hybridize()
    mlm_loss_fn.hybridize()

    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    step_num = args.start_step
    running_mlm_loss, running_nsp_loss = 0., 0.
    running_num_tks = 0

    train_start_time = time.time()
    tic = time.time()
    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
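    # repeat + grouper turn the sharded dataloader into an endless stream of
    # len(ctx_l)-sized batch groups; training stops after num_steps optimizer updates.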
    while step_num < num_steps:
        for _ in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            mlm_loss_l = []
            nsp_loss_l = []
            loss_l = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for sample, ctx in zip(sample_l, ctx_l):
                # prepare data
                (input_id, masked_id, masked_position, masked_weight, \
                    next_sentence_label, segment_id, valid_length) = sample
                input_id = input_id.as_in_ctx(ctx)
                masked_id = masked_id.as_in_ctx(ctx)
                masked_position = masked_position.as_in_ctx(ctx)
                masked_weight = masked_weight.as_in_ctx(ctx)
                next_sentence_label = next_sentence_label.as_in_ctx(ctx)
                segment_id = segment_id.as_in_ctx(ctx)
                valid_length = valid_length.as_in_ctx(ctx)

                with mx.autograd.record():
                    _, _, nsp_score, mlm_scores = model(input_id, segment_id,
                        valid_length, masked_position)
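                    # Normalize the MLM loss by the number of masked tokens and by
                    # (devices * accumulation steps), so the accumulated gradient
                    # approximates a per-masked-token average.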
                    denominator = (masked_weight.sum() + 1e-8) * num_accumulated * len(ctx_l)
                    mlm_scores_r = mx.npx.reshape(mlm_scores, (-5, -1))
                    masked_id_r = masked_id.reshape((-1,))
                    mlm_loss = mlm_loss_fn(
                        mlm_scores_r,
                        masked_id_r,
                        masked_weight.reshape((-1, 1))).sum() / denominator
                    denominator = num_accumulated * len(ctx_l)
                    nsp_loss = nsp_loss_fn(
                        nsp_score, next_sentence_label).mean() / denominator
                    mlm_loss_l.append(mlm_loss)
                    nsp_loss_l.append(nsp_loss)
                    loss_l.append(mlm_loss + nsp_loss)
                    mask_label_list.append(masked_id_r)
                    mask_pred_list.append(mlm_scores_r)
                    mask_weight_list.append(masked_weight.reshape((-1,)))
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(nsp_score)

                running_num_tks += valid_length.sum().as_in_ctx(mx.cpu())

            for loss in loss_l:
                loss.backward()

            running_mlm_loss += sum([ele.as_in_ctx(mx.cpu())
                                    for ele in mlm_loss_l]).asnumpy().item()
            running_nsp_loss += sum([ele.as_in_ctx(mx.cpu())
                                    for ele in nsp_loss_l]).asnumpy().item()
            mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)
            nsp_metric.update(ns_label_list, ns_pred_list)
        # update
        trainer.allreduce_grads()

        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_workers)
        total_norm = total_norm / num_workers

        if args.comm_backend == 'horovod' or args.comm_backend == 'byteps':
            # Note that horovod.trainer._scale defaults to num_workers,
            # thus trainer.update(1) will scale the gradients by 1./num_workers
            trainer.update(1, ignore_stale_grad=True)
        else:
            # gluon.trainer._scale defaults to 1
            trainer.update(num_workers, ignore_stale_grad=True)

        if num_accumulated > 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        step_num += 1
        # saving
        if step_num % save_interval == 0 or step_num >= num_steps:
            states_option(step_num, trainer, args.ckpt_dir, local_rank, 'Saving')
            if local_rank == 0:
                parameters_option(step_num, model, args.ckpt_dir, 'Saving')
        # logging
        if step_num % log_interval == 0:
            running_mlm_loss /= log_interval
            running_nsp_loss /= log_interval
            toc = time.time()
            logging.info(
                '[step {}], Loss mlm/nsp={:.5f}/{:.3f}, Acc mlm/nsp={:.3f}/{:.3f}, '
                ' LR={:.7f}, grad_norm={:.4f}. Time cost={:.2f} s,'
                ' Throughput={:.1f}K tks/s, ETA={:.2f} h'.format(
                    step_num, running_mlm_loss, running_nsp_loss,
                    mlm_metric.get()[1], nsp_metric.get()[1],
                    trainer.learning_rate, total_norm, toc - tic,
                    running_num_tks.asnumpy().item() / (toc - tic) / 1000,
                    (num_steps - step_num) / (step_num / (toc - train_start_time)) / 3600))
            mlm_metric.reset()
            nsp_metric.reset()
            tic = time.time()

            running_mlm_loss = 0
            running_nsp_loss = 0
            running_num_tks = 0

    logging.info('Finish training step: %d', step_num)

    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time - train_start_time))

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.ckpt_dir, model_name)
        final_save(model, save_dir, tokenizer, cfg)
Example #6
def train(args):
    store, num_workers, rank, local_rank, is_master_node, ctx_l = init_comm(
        args.comm_backend, args.gpus)
    logging.info('Training info: num_buckets: {}, '
                 'num_workers: {}, rank: {}'.format(args.num_buckets,
                                                    num_workers, rank))
    cfg, tokenizer, model = get_pretraining_model(args.model_name, ctx_l,
                                                  args.max_seq_length,
                                                  args.hidden_dropout_prob,
                                                  args.attention_dropout_prob,
                                                  args.generator_units_scale,
                                                  args.generator_layers_scale)
    data_masker = ElectraMasker(tokenizer, args.max_seq_length, args.mask_prob)
    if args.from_raw_text:
        if args.cached_file_path and not os.path.exists(args.cached_file_path):
            os.mkdir(args.cached_file_path)
        get_dataset_fn = functools.partial(
            get_pretrain_data_text,
            max_seq_length=args.max_seq_length,
            short_seq_prob=args.short_seq_prob,
            tokenizer=tokenizer,
            circle_length=args.circle_length,
            repeat=args.repeat,
            cached_file_path=args.cached_file_path)

        logging.info(
            'Processing and loading the training dataset from raw text.')

    else:
        logging.info('Loading the training dataset from local Numpy file.')
        get_dataset_fn = get_pretrain_data_npz

    data_train = get_dataset_fn(args.data,
                                args.batch_size,
                                shuffle=True,
                                num_buckets=args.num_buckets,
                                vocab=tokenizer.vocab,
                                num_parts=num_workers,
                                part_idx=rank,
                                num_dataset_workers=args.num_dataset_workers,
                                num_batch_workers=args.num_batch_workers)

    logging.info('Creating distributed trainer...')
    param_dict = model.collect_params()
    # Do not apply weight decay to the LayerNorm and bias parameters
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # Collect differentiable parameters
    params = [p for p in param_dict.values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    if args.num_accumulated > 1:
        logging.info(
            'Using gradient accumulation. Effective global batch size = {}'.
            format(args.num_accumulated * args.batch_size * len(ctx_l) *
                   num_workers))
        for p in params:
            p.grad_req = 'add'
    # backend specific implementation
    if args.comm_backend == 'horovod':
        # Horovod: fetch and broadcast parameters
        hvd.broadcast_parameters(param_dict, root_rank=0)

    num_train_steps = args.num_train_steps
    if args.warmup_steps is not None:
        warmup_steps = args.warmup_steps
    else:
        warmup_steps = int(num_train_steps * args.warmup_ratio)
    assert warmup_steps is not None, 'Must specify either warmup_steps or warmup_ratio'
    log_interval = args.log_interval
    save_interval = args.save_interval if args.save_interval is not None\
        else num_train_steps // 50
    logging.info(
        '#Total Training Steps={}, Warmup={}, Save Interval={}'.format(
            num_train_steps, warmup_steps, save_interval))

    lr_scheduler = PolyScheduler(max_update=num_train_steps,
                                 base_lr=args.lr,
                                 warmup_begin_lr=0,
                                 pwr=1,
                                 final_lr=0,
                                 warmup_steps=warmup_steps,
                                 warmup_mode='linear')
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
    }
    if args.optimizer == 'adamw':
        optimizer_params.update({
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': 1e-6,
            'correct_bias': False,
        })
    if args.comm_backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer,
                                         optimizer_params)
    else:
        trainer = mx.gluon.Trainer(param_dict,
                                   args.optimizer,
                                   optimizer_params,
                                   update_on_kvstore=False)
    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        # TODO(zheyuye), How about data splitting, where to start re-training
        state_path = states_option(args.start_step, trainer, args.output_dir,
                                   local_rank, 'Loading')
        param_path = parameters_option(args.start_step, model, args.output_dir,
                                       'Loading')

    # prepare the loss function
    mlm_loss_fn = mx.gluon.loss.SoftmaxCELoss()
    rtd_loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    mlm_loss_fn.hybridize()
    rtd_loss_fn.hybridize()

    # prepare the records writer
    writer = None
    if args.do_eval and local_rank == 0:
        from tensorboardX import SummaryWriter
        record_path = os.path.join(args.output_dir, 'records')
        logging.info('Evaluation records saved in {}'.format(record_path))
        writer = SummaryWriter(record_path)

    step_num = args.start_step
    finish_flag = False
    num_samples_per_update = 0
    loss_denom = float(len(ctx_l) * args.num_accumulated * num_workers)
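    # Each per-device loss is divided by loss_denom so that, after gradient
    # accumulation and the cross-worker allreduce, the summed gradients correspond
    # to an average over devices, accumulation steps, and workers.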

    log_total_loss = 0
    log_mlm_loss = 0
    log_rtd_loss = 0
    log_sample_num = 0
    train_start_time = time.time()
    if args.num_accumulated != 1:
        # set grad to zero for gradient accumulation
        model.zero_grad()

    # start training
    train_loop_dataloader = grouper(repeat(data_train), len(ctx_l))
    while step_num < num_train_steps:
        tic = time.time()
        for accum_idx in range(args.num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            mlm_loss_l = []
            rtd_loss_l = []
            for sample, ctx in zip(sample_l, ctx_l):
                if sample is None:
                    continue
                # prepare data
                input_ids, segment_ids, valid_lengths = sample
                input_ids = input_ids.as_in_ctx(ctx)
                segment_ids = segment_ids.as_in_ctx(ctx)
                valid_lengths = valid_lengths.as_in_ctx(ctx)
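                # ELECTRA-style dynamic masking: tokens are masked on the fly for
                # every batch rather than pre-masked in the stored dataset.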
                masked_input = data_masker.dynamic_masking(
                    mx.nd, input_ids, valid_lengths)
                masked_input_ids = masked_input.input_ids
                length_masks = masked_input.masks
                unmasked_tokens = masked_input.unmasked_tokens
                masked_positions = masked_input.masked_positions
                masked_weights = masked_input.masked_weights

                log_sample_num += len(masked_input_ids)
                num_samples_per_update += len(masked_input_ids)

                with mx.autograd.record():
                    mlm_scores, rtd_scores, corrupted_tokens, labels = model(
                        masked_input_ids, segment_ids, valid_lengths,
                        unmasked_tokens, masked_positions)
                    # the official implementation takes the sum of each batch inside the loss function
                    # while SigmoidBinaryCrossEntropyLoss and SoftmaxCELoss take the mean value
                    mlm_loss = mlm_loss_fn(
                        mlm_scores, unmasked_tokens, masked_weights.reshape(
                            -1)).mean() / (masked_weights.mean() + 1e-6)
                    rtd_loss = rtd_loss_fn(
                        rtd_scores, labels,
                        length_masks).mean() / (length_masks.mean() + 1e-6)
                    output = ElectraOutput(
                        mlm_scores=mlm_scores,
                        rtd_scores=rtd_scores,
                        rtd_labels=labels,
                        corrupted_tokens=corrupted_tokens,
                    )
                    mlm_loss_l.append(mlm_loss)
                    rtd_loss_l.append(rtd_loss)
                    loss = (args.gen_weight * mlm_loss +
                            args.disc_weight * rtd_loss) / loss_denom
                    loss_l.append(loss)

            for loss in loss_l:
                loss.backward()
            # All Reduce the Step Loss
            log_mlm_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in mlm_loss_l]).asnumpy()
            log_rtd_loss += sum(
                [ele.as_in_ctx(ctx_l[0]) for ele in rtd_loss_l]).asnumpy()
            log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
                                   for ele in loss_l]).asnumpy() * loss_denom

        # update
        trainer.allreduce_grads()
        # Here, the accumulated gradients are
        # \sum_{n=1}^N g_n / loss_denom
        # Thus, in order to clip the average gradient
        #   \frac{1}{N} \sum_{n=1}^N      -->  clip to args.max_grad_norm
        # We need to change the ratio to be
        #  \sum_{n=1}^N g_n / loss_denom  -->  clip to args.max_grad_norm  * N / loss_denom
        total_norm, ratio, is_finite = clip_grad_global_norm(
            params, args.max_grad_norm * num_samples_per_update / loss_denom)
        total_norm = total_norm / (num_samples_per_update / loss_denom)
        trainer.update(num_samples_per_update / loss_denom,
                       ignore_stale_grad=True)
        step_num += 1
        if args.num_accumulated != 1:
            # set grad to zero for gradient accumulation
            model.zero_grad()

        # saving
        if step_num % save_interval == 0 or step_num >= num_train_steps:
            if is_master_node:
                states_option(step_num, trainer, args.output_dir, local_rank,
                              'Saving')
                if local_rank == 0:
                    param_path = parameters_option(step_num, model,
                                                   args.output_dir, 'Saving')

        # logging
        if step_num % log_interval == 0 and local_rank == 0:
            # Log the average loss per step over the last interval
            log_mlm_loss /= log_interval
            log_rtd_loss /= log_interval
            log_total_loss /= log_interval
            toc = time.time()
            logging.info('[step {}], Loss mlm/rtd/total={:.4f}/{:.4f}/{:.4f},'
                         ' LR={:.6f}, grad_norm={:.4f}. Time cost={:.2f},'
                         ' Throughput={:.2f} samples/s, ETA={:.2f}h'.format(
                             step_num, log_mlm_loss, log_rtd_loss,
                             log_total_loss, trainer.learning_rate, total_norm,
                             toc - tic, log_sample_num / (toc - tic),
                             (num_train_steps - step_num) /
                             (step_num / (toc - train_start_time)) / 3600))
            tic = time.time()

            if args.do_eval:
                evaluation(writer, step_num, masked_input, output)
                writer.add_scalars(
                    'loss', {
                        'total_loss': log_total_loss,
                        'mlm_loss': log_mlm_loss,
                        'rtd_loss': log_rtd_loss
                    }, step_num)
            log_mlm_loss = 0
            log_rtd_loss = 0
            log_total_loss = 0
            log_sample_num = 0

        num_samples_per_update = 0

    logging.info('Finish training step: %d', step_num)
    if is_master_node:
        state_path = states_option(step_num, trainer, args.output_dir,
                                   local_rank, 'Saving')
        if local_rank == 0:
            param_path = parameters_option(step_num, model, args.output_dir,
                                           'Saving')

    mx.npx.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_start_time))
    if writer is not None:
        writer.close()

    if local_rank == 0:
        model_name = args.model_name.replace('google', 'gluon')
        save_dir = os.path.join(args.output_dir, model_name)
        final_save(model, save_dir, tokenizer)