Example #1
def evaluate(data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx,
             log_interval, dtype):
    """Evaluation function."""
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0
    running_mlm_loss = running_nsp_loss = 0
    total_mlm_loss = total_nsp_loss = 0
    running_num_tks = 0
    for _, dataloader in enumerate(data_eval):
        for _, data_batch in enumerate(dataloader):
            step_num += 1

            data_list = split_and_load(data_batch, ctx)
            loss_list = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for data in data_list:
                out = forward(data, model, mlm_loss, nsp_loss, vocab_size,
                              dtype)
                (ls, next_sentence_label, classified, masked_id, decoded,
                 masked_weight, ls1, ls2, valid_length) = out
                loss_list.append(ls)
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)

                running_mlm_loss += ls1.as_in_context(mx.cpu())
                running_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            nsp_metric.update(ns_label_list, ns_pred_list)
            mlm_metric.update(mask_label_list, mask_pred_list,
                              mask_weight_list)

            # logging
            if (step_num + 1) % (log_interval) == 0:
                total_mlm_loss += running_mlm_loss
                total_nsp_loss += running_nsp_loss
                log(begin_time, running_num_tks, running_mlm_loss,
                    running_nsp_loss, step_num, mlm_metric, nsp_metric, None,
                    log_interval)
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0
                mlm_metric.reset_local()
                nsp_metric.reset_local()

    mx.nd.waitall()
    eval_end_time = time.time()
    # accumulate losses from last few batches, too
    if running_mlm_loss != 0:
        total_mlm_loss += running_mlm_loss
        total_nsp_loss += running_nsp_loss
    total_mlm_loss /= step_num
    total_nsp_loss /= step_num
    logging.info(
        'mlm_loss={:.3f}\tmlm_acc={:.1f}\tnsp_loss={:.3f}\tnsp_acc={:.1f}\t'.
        format(total_mlm_loss.asscalar(),
               mlm_metric.get_global()[1] * 100, total_nsp_loss.asscalar(),
               nsp_metric.get_global()[1] * 100))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
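
The evaluation loop above relies on a split_and_load helper that is not part of the listing. Below is a minimal sketch of what such a helper typically looks like in these GluonNLP pretraining scripts, assuming each batch is a list of NDArrays sharing the batch axis; the exact implementation in the original script may differ.

# Minimal sketch of the split_and_load helper assumed above (an assumption,
# not the original code). Each batch field is split across the contexts and
# the result is transposed so one element holds every field for one device.
from mxnet import gluon

def split_and_load(arrs, ctx):
    """Split each array along the batch axis and load one shard per context."""
    loaded = [gluon.utils.split_and_load(arr, ctx, even_split=False)
              for arr in arrs]
    return list(zip(*loaded))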
Example #2
def train(data_train, dataset_eval, model, teacher_model, mlm_loss,
          teacher_ce_loss, teacher_mse_loss, vocab_size, ctx,
          teacher_ce_weight, distillation_temperature, mlm_weight, log_tb):
    """Training function."""
    params = model.collect_params()
    if params is not None:
        hvd.broadcast_parameters(params, root_rank=0)

    mlm_metric = MaskedAccuracy()
    mlm_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers}
    else:
        loss_scale_param = None
    trainer = hvd.DistributedTrainer(params, 'bertadam', optim_params)

    if args.dtype == 'float16':
        fp16_trainer = FP16Trainer(trainer,
                                   dynamic_loss_scale=dynamic_loss_scale,
                                   loss_scaler_params=loss_scale_param)
        trainer_step = lambda: fp16_trainer.step(1, max_norm=1 * num_workers)
    else:
        trainer_step = lambda: trainer.step(1)

    if args.start_step:
        out_dir = os.path.join(args.ckpt_dir, f"checkpoint_{args.start_step}")
        state_path = os.path.join(
            out_dir, '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [
        p for p in model.collect_params().values() if p.grad_req != 'null'
    ]
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_teacher_ce_loss, running_teacher_mse_loss = 0, 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')

    pbar = tqdm(total=num_train_steps, desc="Training:")

    while step_num < num_train_steps:
        for raw_batch_num, data_batch in enumerate(data_train):
            sys.stdout.flush()
            if step_num >= num_train_steps:
                break
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is
                # required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num,
                            10,
                            14,
                            profile_name=args.profile + str(rank))

            # load data
            if args.use_avg_len:
                data_list = [[[s.as_in_context(context) for s in seq]
                              for seq in shard]
                             for context, shard in zip([ctx], data_batch)]
            else:
                data_list = list(split_and_load(data_batch, [ctx]))
            data = data_list

            # forward
            with mx.autograd.record():
                (loss_val, ns_label, classified, masked_id, decoded,
                 masked_weight, mlm_loss_val, teacher_ce_loss_val,
                 teacher_mse_loss_val, valid_len) = forward(
                     data,
                     model,
                     mlm_loss,
                     vocab_size,
                     args.dtype,
                     mlm_weight=mlm_weight,
                     teacher_ce_loss=teacher_ce_loss,
                     teacher_mse_loss=teacher_mse_loss,
                     teacher_model=teacher_model,
                     teacher_ce_weight=teacher_ce_weight,
                     distillation_temperature=distillation_temperature)
                loss_val = loss_val / accumulate
                # backward
                if args.dtype == 'float16':
                    fp16_trainer.backward(loss_val)
                else:
                    loss_val.backward()

            running_mlm_loss += mlm_loss_val.as_in_context(mx.cpu())
            running_teacher_ce_loss += teacher_ce_loss_val.as_in_context(
                mx.cpu())
            running_teacher_mse_loss += teacher_mse_loss_val.as_in_context(
                mx.cpu())
            running_num_tks += valid_len.sum().as_in_context(mx.cpu())

            # update
            if (batch_num + 1) % accumulate == 0:
                # step() performs three things:
                # 1. allreduce gradients from all workers
                # 2. check the global norm of the gradients and clip them if necessary
                # 3. average the gradients and apply the update
                trainer_step()

            mlm_metric.update([masked_id], [decoded], [masked_weight])

            # logging
            if step_num % args.log_interval == 0 and batch_num % accumulate == 0:
                log("train ",
                    begin_time,
                    running_num_tks,
                    running_mlm_loss / accumulate,
                    running_teacher_ce_loss / accumulate,
                    running_teacher_mse_loss / accumulate,
                    step_num,
                    mlm_metric,
                    trainer,
                    args.log_interval,
                    model=model,
                    log_tb=log_tb,
                    is_master_node=is_master_node)
                begin_time = time.time()
                running_mlm_loss = running_teacher_ce_loss = running_teacher_mse_loss = running_num_tks = 0
                mlm_metric.reset_local()

            # saving checkpoints
            if step_num % args.ckpt_interval == 0 and batch_num % accumulate == 0:
                if is_master_node:
                    out_dir = os.path.join(args.ckpt_dir,
                                           f"checkpoint_{step_num}")
                    if not os.path.isdir(out_dir):
                        nlp.utils.mkdir(out_dir)
                    save_states(step_num, trainer, out_dir, local_rank)
                    if local_rank == 0:
                        save_parameters(step_num, model, out_dir)
                if data_eval:
                    dataset_eval = get_pretrain_data_npz(
                        data_eval, args.batch_size_eval, 1, False, False, 1)
                    evaluate(dataset_eval,
                             model,
                             mlm_loss,
                             len(vocab), [ctx],
                             args.log_interval,
                             args.dtype,
                             mlm_weight=mlm_weight,
                             teacher_ce_loss=teacher_ce_loss,
                             teacher_mse_loss=teacher_mse_loss,
                             teacher_model=teacher_model,
                             teacher_ce_weight=teacher_ce_weight,
                             distillation_temperature=distillation_temperature,
                             log_tb=log_tb)

            batch_num += 1
        pbar.update(1)
        del data_batch
    if is_master_node:
        out_dir = os.path.join(args.ckpt_dir, "checkpoint_last")
        if not os.path.isdir(out_dir):
            os.mkdir(out_dir)
        save_states(step_num, trainer, out_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    pbar.close()
    logging.info('Train cost={:.1f}s'.format(train_end_time -
                                             train_begin_time))
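
The training loop above computes the learning rate by hand at every optimization step. Restated as a small stand-alone function for clarity, the schedule is a linear warm-up followed by a linear decay; the function name below is illustrative, not part of the original script.

# Stand-alone restatement of the learning-rate schedule used in the training
# loop above: linear warm-up to base_lr, then linear decay toward zero.
def linear_warmup_decay_lr(base_lr, step_num, num_warmup_steps, num_train_steps):
    if step_num <= num_warmup_steps:
        return base_lr * step_num / num_warmup_steps
    # After warm-up the rate is reduced by an offset that grows with progress.
    return base_lr - base_lr * step_num / num_train_steps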
Example #3
def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
    """Training function."""
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    trainer = gluon.Trainer(model.collect_params(), 'bertadam', optim_params,
                            update_on_kvstore=False, kvstore=store)
    dynamic_loss_scale = args.dtype == 'float16'
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale)

    if args.ckpt_dir and args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states' % args.start_step)
        logging.info('Loading trainer state from %s', state_path)
        trainer.load_states(state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    for p in params:
        p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    local_mlm_loss = 0
    local_nsp_loss = 0
    local_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    parallel_model = ParallelBERT(model, mlm_loss, nsp_loss, vocab_size,
                                  store.num_workers * accumulate, trainer=fp16_trainer)
    num_ctxes = len(ctx)
    parallel = Parallel(num_ctxes, parallel_model)

    while step_num < num_train_steps:
        for _, dataloader in enumerate(data_train):
            if step_num >= num_train_steps:
                break
            for _, data_batch in enumerate(dataloader):
                if step_num >= num_train_steps:
                    break
                if batch_num % accumulate == 0:
                    step_num += 1
                    # zero grad
                    model.collect_params().zero_grad()
                    # update learning rate
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 12)
                if args.by_token:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip(ctx, data_batch)]
                else:
                    if data_batch[0].shape[0] < len(ctx):
                        continue
                    data_list = split_and_load(data_batch, ctx)

                ns_label_list, ns_pred_list = [], []
                mask_label_list, mask_pred_list, mask_weight_list = [], [], []

                # parallel forward / backward
                for data in data_list:
                    parallel.put(data)
                for _ in range(len(ctx)):
                    (_, next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    local_mlm_loss += ls1.as_in_context(mx.cpu()) / num_ctxes
                    local_nsp_loss += ls2.as_in_context(mx.cpu()) / num_ctxes
                    local_num_tks += valid_length.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    fp16_trainer.step(1, max_norm=1)
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)
                # logging
                if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                    log(begin_time, local_num_tks, local_mlm_loss / accumulate,
                        local_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric, trainer)
                    begin_time = time.time()
                    local_mlm_loss = local_nsp_loss = local_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if args.ckpt_dir and (step_num + 1) % (args.ckpt_interval) == 0 \
                   and (batch_num + 1) % accumulate == 0:
                    save_params(step_num, args, model, trainer)
                batch_num += 1
    save_params(step_num, args, model, trainer)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
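
Example #3 checkpoints through a save_params helper that is defined elsewhere in the script. A hedged sketch of what such a helper typically does is shown below; the file-name pattern follows the '%07d.states' convention used when the trainer state is loaded, but the exact signature and layout are assumptions.

# Hedged sketch of a save_params helper compatible with the call sites above
# (an assumption, not the original code).
import os

def save_params(step_num, args, model, trainer):
    if not args.ckpt_dir:
        return
    param_path = os.path.join(args.ckpt_dir, '%07d.params' % step_num)
    state_path = os.path.join(args.ckpt_dir, '%07d.states' % step_num)
    model.save_parameters(param_path)   # model weights
    trainer.save_states(state_path)     # optimizer / trainer state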
Example #4
def evaluate(data_eval, model, mlm_loss, vocab_size, ctx, log_interval, dtype,
             mlm_weight=1.0, teacher_ce_loss=None, teacher_mse_loss=None, teacher_model=None, teacher_ce_weight=0.0,
             distillation_temperature=1.0, log_tb=None):
    """Evaluation function."""
    logging.info('Running evaluation ... ')
    mlm_metric = MaskedAccuracy()
    mlm_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0
    running_mlm_loss = 0
    total_mlm_loss = 0
    running_teacher_ce_loss = running_teacher_mse_loss = 0
    total_teacher_ce_loss = total_teacher_mse_loss = 0
    running_num_tks = 0

    for _, data_batch in tqdm(enumerate(data_eval), desc="Evaluation"):
        step_num += 1
        data_list = [[seq.as_in_context(context) for seq in shard]
                     for context, shard in zip(ctx, data_batch)]
        loss_list = []
        ns_label_list, ns_pred_list = [], []
        mask_label_list, mask_pred_list, mask_weight_list = [], [], []
        for data in data_list:
            out = forward(data, model, mlm_loss, vocab_size, dtype, is_eval=True,
                          mlm_weight=mlm_weight,
                          teacher_ce_loss=teacher_ce_loss, teacher_mse_loss=teacher_mse_loss,
                          teacher_model=teacher_model, teacher_ce_weight=teacher_ce_weight,
                          distillation_temperature=distillation_temperature)
            (loss_val, next_sentence_label, classified, masked_id,
             decoded, masked_weight, mlm_loss_val, teacher_ce_loss_val, teacher_mse_loss_val, valid_length) = out
            loss_list.append(loss_val)
            ns_label_list.append(next_sentence_label)
            ns_pred_list.append(classified)
            mask_label_list.append(masked_id)
            mask_pred_list.append(decoded)
            mask_weight_list.append(masked_weight)

            running_mlm_loss += mlm_loss_val.as_in_context(mx.cpu())
            running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            running_teacher_ce_loss += teacher_ce_loss_val.as_in_context(
                mx.cpu())
            running_teacher_mse_loss += teacher_mse_loss_val.as_in_context(
                mx.cpu())
        mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

        # logging
        if (step_num + 1) % (log_interval) == 0:
            total_mlm_loss += running_mlm_loss
            total_teacher_ce_loss += running_teacher_ce_loss
            total_teacher_mse_loss += running_teacher_mse_loss
            log("eval ",
                begin_time,
                running_num_tks,
                running_mlm_loss,
                running_teacher_ce_loss,
                running_teacher_mse_loss,
                step_num,
                mlm_metric,
                None,
                log_interval,
                model=model,
                log_tb=log_tb)
            begin_time = time.time()
            running_mlm_loss = running_num_tks = 0
            running_teacher_ce_loss = running_teacher_mse_loss = 0
            mlm_metric.reset_local()

    mx.nd.waitall()
    eval_end_time = time.time()
    # accumulate losses from last few batches, too
    if running_mlm_loss != 0:
        total_mlm_loss += running_mlm_loss
        total_teacher_ce_loss += running_teacher_ce_loss
        total_teacher_mse_loss += running_teacher_mse_loss
    total_mlm_loss /= step_num
    total_teacher_ce_loss /= step_num
    total_teacher_mse_loss /= step_num
    logging.info('Eval mlm_loss={:.3f}\tmlm_acc={:.1f}\tteacher_ce={:.2e}\tteacher_mse={:.2e}'
                 .format(total_mlm_loss.asscalar(), mlm_metric.get_global()[1] * 100,
                         total_teacher_ce_loss.asscalar(), total_teacher_mse_loss.asscalar()))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
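
Examples #2 and #4 feed teacher_ce_loss, teacher_mse_loss and a distillation_temperature into a forward helper that is not shown here. As a rough illustration of the two teacher terms they accumulate, a common choice in knowledge distillation is a temperature-softened cross-entropy against the teacher's distribution plus an MSE between student and teacher logits; the function and names below are our assumptions, not the original forward().

# Hedged sketch of how teacher_ce / teacher_mse terms are commonly computed
# in knowledge distillation (illustrative; not the original forward()).
import mxnet as mx

def distillation_losses(student_logits, teacher_logits, temperature=1.0):
    t = temperature
    soft_targets = mx.nd.softmax(teacher_logits / t)   # softened teacher distribution
    log_probs = mx.nd.log_softmax(student_logits / t)
    # Cross-entropy with soft targets, rescaled by t^2 as is conventional.
    teacher_ce = -(soft_targets * log_probs).sum(axis=-1).mean() * (t * t)
    teacher_mse = ((student_logits - teacher_logits) ** 2).mean()
    return teacher_ce, teacher_mse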
Example #5
def evaluate(data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Evaluation function."""
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0

    # Total loss for the whole dataset
    total_mlm_loss = total_nsp_loss = 0

    # Running loss, reset when a log is emitted
    running_mlm_loss = running_nsp_loss = 0
    running_num_tks = 0
    for _, dataloader in enumerate(data_eval):
        for _, data_batch in enumerate(dataloader):
            step_num += 1

            data_list = split_and_load(data_batch, ctx)
            loss_list = []
            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            # Run inference on the batch, collect the predictions and losses
            batch_mlm_loss = batch_nsp_loss = 0
            for data in data_list:
                out = forward(data, model, mlm_loss, nsp_loss, vocab_size)
                (ls, next_sentence_label, classified, masked_id,
                 decoded, masked_weight, ls1, ls2, valid_length) = out

                loss_list.append(ls)
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)

                batch_mlm_loss += ls1.as_in_context(mx.cpu())
                batch_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())

            running_mlm_loss += batch_mlm_loss
            running_nsp_loss += batch_nsp_loss
            total_mlm_loss += batch_mlm_loss
            total_nsp_loss += batch_nsp_loss

            nsp_metric.update(ns_label_list, ns_pred_list)
            mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # Log and reset running loss
            if (step_num + 1) % (args.log_interval) == 0:
                log(begin_time, running_num_tks, running_mlm_loss, running_nsp_loss,
                    step_num, mlm_metric, nsp_metric, None)
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0
                mlm_metric.reset_running()
                nsp_metric.reset_running()

    mx.nd.waitall()
    eval_end_time = time.time()
    total_mlm_loss /= step_num
    total_nsp_loss /= step_num
    logging.info('mlm_loss={:.3f}\tmlm_acc={:.1f}\tnsp_loss={:.3f}\tnsp_acc={:.1f}\t'
                 .format(total_mlm_loss.asscalar(), mlm_metric.get_global()[1] * 100,
                         total_nsp_loss.asscalar(), nsp_metric.get_global()[1] * 100))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
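
Every example above delegates the model call to a forward helper defined elsewhere in the script. A hedged sketch of the MLM + NSP variant matching the signature used in Example #5 is given below (Example #1 additionally passes dtype); the field order of the data shard, the model's return values and the loss normalization are assumptions inferred from the call sites, not the original code.

# Hedged sketch of the forward() helper assumed by the MLM + NSP examples.
# Assumptions: the data shard unpacks into the fields below, the pretraining
# model returns (classified, decoded) for the NSP and MLM heads, and both
# losses are Gluon SoftmaxCELoss instances.
def forward(data, model, mlm_loss, nsp_loss, vocab_size):
    (input_id, masked_id, masked_position, masked_weight,
     next_sentence_label, segment_id, valid_length) = data
    num_masks = masked_weight.sum() + 1e-8
    valid_length = valid_length.reshape(-1)
    masked_id = masked_id.reshape(-1)
    classified, decoded = model(input_id, segment_id,
                                valid_length.astype('float32'), masked_position)
    decoded = decoded.reshape((-1, vocab_size))
    ls1 = mlm_loss(decoded.astype('float32'), masked_id,
                   masked_weight.reshape((-1, 1)))   # per-token MLM loss
    ls2 = nsp_loss(classified.astype('float32'), next_sentence_label)
    ls1 = ls1.sum() / num_masks   # average over masked tokens
    ls2 = ls2.mean()              # average over sentence pairs
    return (ls1 + ls2, next_sentence_label, classified, masked_id, decoded,
            masked_weight, ls1, ls2, valid_length)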