예제 #1
0
def evaluate(data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Run evaluation over ``data_eval`` and log average losses and accuracies.

    Every batch of every dataloader in ``data_eval`` is split across ``ctx``
    with ``split_and_load`` and passed through ``forward``. Masked-LM and
    next-sentence predictions are fed into two ``MaskedAccuracy`` metrics.
    Intermediate progress is reported through ``log`` based on
    ``args.log_interval``; the final per-step average losses and the global
    accuracies are reported with ``logging.info``.

    Parameters
    ----------
    data_eval : iterable
        Iterable of dataloaders; each dataloader yields batches that
        ``split_and_load`` accepts.
    model
        Passed through to ``forward``.
    nsp_loss
        Next-sentence loss, passed through to ``forward``.
    mlm_loss
        Masked-LM loss, passed through to ``forward``.
    vocab_size
        Passed through to ``forward``.
    ctx
        Context(s) handed to ``split_and_load``.

    Returns
    -------
    None
        Results are only logged. If ``data_eval`` yields no batches, a
        warning is logged and no summary is produced.
    """
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0
    # "local" values cover the steps since the last log line; "total" values
    # cover the whole evaluation run.
    local_mlm_loss = local_nsp_loss = 0
    total_mlm_loss = total_nsp_loss = 0
    local_num_tks = 0
    cpu_ctx = mx.cpu()
    for dataloader in data_eval:
        for data in dataloader:
            step_num += 1

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []
            for shard in split_and_load(data, ctx):
                out = forward(shard, model, mlm_loss, nsp_loss, vocab_size)
                # The combined loss (first element) is not needed here.
                (_, next_sentence_label, classified, masked_id, decoded,
                 masked_weight, ls1, ls2, valid_length) = out
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)

                # Losses are gathered on CPU so shards from different
                # contexts can be summed together.
                local_mlm_loss += ls1.as_in_context(cpu_ctx)
                local_nsp_loss += ls2.as_in_context(cpu_ctx)
                local_num_tks += valid_length.sum().as_in_context(cpu_ctx)
            nsp_metric.update(ns_label_list, ns_pred_list)
            mlm_metric.update(mask_label_list, mask_pred_list,
                              mask_weight_list)

            # logging
            # NOTE(review): step_num has already been incremented above, so
            # this fires at steps log_interval-1, 2*log_interval-1, ...; the
            # cadence is kept as-is because `log` is not visible here --
            # confirm whether `step_num % log_interval` was intended.
            if (step_num + 1) % (args.log_interval) == 0:
                total_mlm_loss += local_mlm_loss
                total_nsp_loss += local_nsp_loss
                log(begin_time, local_num_tks, local_mlm_loss, local_nsp_loss,
                    step_num, mlm_metric, nsp_metric, None)
                begin_time = time.time()
                local_mlm_loss = local_nsp_loss = local_num_tks = 0
                mlm_metric.reset_local()
                nsp_metric.reset_local()

    mx.nd.waitall()
    eval_end_time = time.time()

    if step_num == 0:
        # Nothing was evaluated: averaging would divide by zero and the
        # integer totals have no ``asscalar``.
        logging.warning('Evaluation skipped: data_eval produced no batches.')
        return

    # Fold in the losses of the steps processed since the last log line;
    # otherwise they would be counted in step_num but missing from the sum.
    total_mlm_loss += local_mlm_loss
    total_nsp_loss += local_nsp_loss
    total_mlm_loss /= step_num
    total_nsp_loss /= step_num
    logging.info(
        'mlm_loss={:.3f}\tmlm_acc={:.1f}\tnsp_loss={:.3f}\tnsp_acc={:.1f}\t'.
        format(total_mlm_loss.asscalar(),
               mlm_metric.get_global()[1] * 100, total_nsp_loss.asscalar(),
               nsp_metric.get_global()[1] * 100))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
예제 #2
0
def evaluate(data_eval, model, mlm_loss, vocab_size, ctx, log_interval, dtype,
             mlm_weight=1.0, teacher_ce_loss=None, teacher_mse_loss=None, teacher_model=None, teacher_ce_weight=0.0,
             distillation_temperature=1.0, log_tb=None):
    """Run evaluation over ``data_eval`` and log average losses and MLM accuracy.

    Each item of ``data_eval`` is treated as a sequence of per-context
    shards: shard ``i`` is moved to ``ctx[i]`` and passed through ``forward``
    with ``is_eval=True``. Masked-LM predictions feed a ``MaskedAccuracy``
    metric, and the masked-LM, teacher cross-entropy and teacher MSE losses
    returned by ``forward`` are summed on CPU. Intermediate progress is
    reported through ``log`` based on ``log_interval``; the final per-step
    averages are reported with ``logging.info``.

    Parameters
    ----------
    data_eval : iterable
        Iterable whose items are sequences of shards (one per context); each
        shard is an iterable of arrays supporting ``as_in_context``.
    model
        Passed through to ``forward`` and ``log``.
    mlm_loss
        Masked-LM loss, passed through to ``forward``.
    vocab_size
        Passed through to ``forward``.
    ctx : sequence
        Contexts, zipped with the shards of each item.
    log_interval : int
        Controls how often ``log`` is called; also passed to ``log``.
    dtype
        Passed through to ``forward``.
    mlm_weight, teacher_ce_loss, teacher_mse_loss, teacher_model, \
teacher_ce_weight, distillation_temperature
        Distillation settings, passed through to ``forward`` unchanged.
    log_tb
        Passed through to ``log``.

    Returns
    -------
    None
        Results are only logged. If ``data_eval`` yields no items, a warning
        is logged and no summary is produced.
    """
    logging.info('Running evaluation ... ')
    mlm_metric = MaskedAccuracy()
    mlm_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0
    # "running" values cover the steps since the last log line; "total"
    # values cover the whole evaluation run.
    running_mlm_loss = 0
    total_mlm_loss = 0
    running_teacher_ce_loss = running_teacher_mse_loss = 0
    total_teacher_ce_loss = total_teacher_mse_loss = 0
    running_num_tks = 0
    cpu_ctx = mx.cpu()

    # Wrap the iterable directly (the index was unused): wrapping
    # ``enumerate`` hides the iterable's length from the progress bar.
    for batch_shards in tqdm(data_eval, desc="Evaluation"):
        step_num += 1
        data_list = [[seq.as_in_context(context) for seq in shard]
                     for context, shard in zip(ctx, batch_shards)]
        mask_label_list, mask_pred_list, mask_weight_list = [], [], []
        for data in data_list:
            out = forward(data, model, mlm_loss, vocab_size, dtype, is_eval=True,
                          mlm_weight=mlm_weight,
                          teacher_ce_loss=teacher_ce_loss, teacher_mse_loss=teacher_mse_loss,
                          teacher_model=teacher_model, teacher_ce_weight=teacher_ce_weight,
                          distillation_temperature=distillation_temperature)
            # The first three outputs (combined loss and the two
            # next-sentence values) are not used by this evaluation.
            (_, _, _, masked_id, decoded, masked_weight, mlm_loss_val,
             teacher_ce_loss_val, teacher_mse_loss_val, valid_length) = out
            mask_label_list.append(masked_id)
            mask_pred_list.append(decoded)
            mask_weight_list.append(masked_weight)

            # Losses are gathered on CPU so shards from different contexts
            # can be summed together.
            running_mlm_loss += mlm_loss_val.as_in_context(cpu_ctx)
            running_num_tks += valid_length.sum().as_in_context(cpu_ctx)
            running_teacher_ce_loss += teacher_ce_loss_val.as_in_context(
                cpu_ctx)
            running_teacher_mse_loss += teacher_mse_loss_val.as_in_context(
                cpu_ctx)
        mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

        # logging
        # NOTE(review): step_num has already been incremented above, so this
        # fires at steps log_interval-1, 2*log_interval-1, ...; the cadence is
        # kept as-is because `log` is not visible here -- confirm whether
        # `step_num % log_interval` was intended.
        if (step_num + 1) % (log_interval) == 0:
            total_mlm_loss += running_mlm_loss
            total_teacher_ce_loss += running_teacher_ce_loss
            total_teacher_mse_loss += running_teacher_mse_loss
            log("eval ",
                begin_time,
                running_num_tks,
                running_mlm_loss,
                running_teacher_ce_loss,
                running_teacher_mse_loss,
                step_num,
                mlm_metric,
                None,
                log_interval,
                model=model,
                log_tb=log_tb)
            begin_time = time.time()
            running_mlm_loss = running_num_tks = 0
            running_teacher_ce_loss = running_teacher_mse_loss = 0
            mlm_metric.reset_local()

    mx.nd.waitall()
    eval_end_time = time.time()

    if step_num == 0:
        # Nothing was evaluated: averaging would divide by zero and the
        # integer totals have no ``asscalar``.
        logging.warning('Evaluation skipped: data_eval produced no batches.')
        return

    # Accumulate losses from the last few batches, too. This is done
    # unconditionally: adding a zero remainder is harmless, and the teacher
    # losses must not be dropped just because the MLM remainder is zero.
    total_mlm_loss += running_mlm_loss
    total_teacher_ce_loss += running_teacher_ce_loss
    total_teacher_mse_loss += running_teacher_mse_loss
    total_mlm_loss /= step_num
    total_teacher_ce_loss /= step_num
    total_teacher_mse_loss /= step_num
    logging.info('Eval mlm_loss={:.3f}\tmlm_acc={:.1f}\tteacher_ce={:.2e}\tteacher_mse={:.2e}'
                 .format(total_mlm_loss.asscalar(), mlm_metric.get_global()[1] * 100,
                         total_teacher_ce_loss.asscalar(), total_teacher_mse_loss.asscalar()))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))
예제 #3
0
def evaluate(data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Run evaluation over ``data_eval`` and log average losses and accuracies.

    Every batch of every dataloader in ``data_eval`` is split across ``ctx``
    with ``split_and_load`` and passed through ``forward``. Masked-LM and
    next-sentence predictions are fed into two ``MaskedAccuracy`` metrics.
    Intermediate progress is reported through ``log`` based on
    ``args.log_interval``; the final per-step average losses and the global
    accuracies are reported with ``logging.info``.

    Parameters
    ----------
    data_eval : iterable
        Iterable of dataloaders; each dataloader yields batches that
        ``split_and_load`` accepts.
    model
        Passed through to ``forward``.
    nsp_loss
        Next-sentence loss, passed through to ``forward``.
    mlm_loss
        Masked-LM loss, passed through to ``forward``.
    vocab_size
        Passed through to ``forward``.
    ctx
        Context(s) handed to ``split_and_load``.

    Returns
    -------
    None
        Results are only logged. If ``data_eval`` yields no batches, a
        warning is logged and no summary is produced.
    """
    mlm_metric = MaskedAccuracy()
    nsp_metric = MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    eval_begin_time = time.time()
    begin_time = time.time()
    step_num = 0

    # Total loss for the whole dataset
    total_mlm_loss = total_nsp_loss = 0

    # Running loss, reset when a log is emitted
    running_mlm_loss = running_nsp_loss = 0
    running_num_tks = 0
    cpu_ctx = mx.cpu()
    for dataloader in data_eval:
        for data in dataloader:
            step_num += 1

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            # Run inference on the batch, collect the predictions and losses
            batch_mlm_loss = batch_nsp_loss = 0
            for shard in split_and_load(data, ctx):
                out = forward(shard, model, mlm_loss, nsp_loss, vocab_size)
                # The combined loss (first element) is not needed here.
                (_, next_sentence_label, classified, masked_id,
                 decoded, masked_weight, ls1, ls2, valid_length) = out

                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)

                # Losses are gathered on CPU so shards from different
                # contexts can be summed together.
                batch_mlm_loss += ls1.as_in_context(cpu_ctx)
                batch_nsp_loss += ls2.as_in_context(cpu_ctx)
                running_num_tks += valid_length.sum().as_in_context(cpu_ctx)

            running_mlm_loss += batch_mlm_loss
            running_nsp_loss += batch_nsp_loss
            total_mlm_loss += batch_mlm_loss
            total_nsp_loss += batch_nsp_loss

            nsp_metric.update(ns_label_list, ns_pred_list)
            mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # Log and reset running loss
            # NOTE(review): step_num has already been incremented above, so
            # this fires at steps log_interval-1, 2*log_interval-1, ...; the
            # cadence is kept as-is because `log` is not visible here --
            # confirm whether `step_num % log_interval` was intended.
            if (step_num + 1) % (args.log_interval) == 0:
                log(begin_time, running_num_tks, running_mlm_loss, running_nsp_loss,
                    step_num, mlm_metric, nsp_metric, None)
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0
                mlm_metric.reset_running()
                nsp_metric.reset_running()

    mx.nd.waitall()
    eval_end_time = time.time()

    if step_num == 0:
        # Nothing was evaluated: averaging would divide by zero and the
        # integer totals have no ``asscalar``.
        logging.warning('Evaluation skipped: data_eval produced no batches.')
        return

    total_mlm_loss /= step_num
    total_nsp_loss /= step_num
    logging.info('mlm_loss={:.3f}\tmlm_acc={:.1f}\tnsp_loss={:.3f}\tnsp_acc={:.1f}\t'
                 .format(total_mlm_loss.asscalar(), mlm_metric.get_global()[1] * 100,
                         total_nsp_loss.asscalar(), nsp_metric.get_global()[1] * 100))
    logging.info('Eval cost={:.1f}s'.format(eval_end_time - eval_begin_time))