Пример #1
0
def run_test(test_dataloader, jiant_model: JiantModel, task, device, local_rank, verbose=True):
    """Run inference over a task's test split and collect predictions.

    Only does work in the non-distributed setting (``local_rank == -1``); on
    any other rank it returns ``None`` immediately.

    Args:
        test_dataloader: iterable yielding ``(batch, batch_metadata)`` pairs.
        jiant_model: model to run; it is switched to eval mode.
        task: task whose evaluation scheme supplies the accumulator and the
            prediction extraction.
        device: device each batch is moved to before the forward pass.
        local_rank: distributed rank; ``-1`` means non-distributed.
        verbose: forwarded to ``maybe_tqdm`` to toggle the progress bar.

    Returns:
        ``None`` when ``local_rank != -1``; otherwise a dict with ``"preds"``
        (from ``get_preds_from_accumulator``) and ``"accumulator"`` (the
        filled accumulator).
    """
    if local_rank != -1:
        return
    jiant_model.eval()
    scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = scheme.get_accumulator()

    progress = maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    for batch, batch_metadata in progress:
        on_device = batch.to(device)

        # No labels at test time, so the forward pass skips loss computation.
        with torch.no_grad():
            forward_result = wrap_jiant_forward(
                jiant_model=jiant_model, batch=on_device, task=task, compute_loss=False,
            )
        logits = forward_result.logits.detach().cpu().numpy()
        accumulator.update(
            batch_logits=logits, batch_loss=0, batch=on_device, batch_metadata=batch_metadata,
        )

    predictions = scheme.get_preds_from_accumulator(task=task, accumulator=accumulator)
    return {"preds": predictions, "accumulator": accumulator}
Пример #2
0
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    return_preds=False,
    verbose=True,
):
    """Evaluate ``jiant_model`` on a task's validation split.

    Only does work in the non-distributed setting (``local_rank == -1``); on
    any other rank it returns ``None`` immediately.

    Args:
        val_dataloader: iterable yielding ``(batch, batch_metadata)`` pairs.
        val_labels: labels passed to the evaluation scheme's
            ``compute_metrics_from_accumulator``.
        jiant_model: model to evaluate (possibly data-parallel wrapped); it is
            switched to eval mode.
        task: task whose evaluation scheme supplies accumulation and metrics.
        device: device each batch is moved to before the forward pass.
        local_rank: distributed rank; ``-1`` means non-distributed.
        return_preds: when True, also include ``"preds"`` in the result.
        verbose: forwarded to ``maybe_tqdm`` to toggle the progress bar.

    Returns:
        ``None`` when ``local_rank != -1``; otherwise a dict with
        ``"accumulator"``, ``"loss"`` (mean over batches of each batch's mean
        loss, or NaN if the dataloader yielded no batches), ``"metrics"`` and,
        if requested, ``"preds"``.
    """
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    if local_rank != -1:
        return
    jiant_model.eval()
    total_eval_loss = 0
    nb_eval_steps = 0
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    for batch, batch_metadata in maybe_tqdm(
        val_dataloader, desc=f"Eval ({task.name}, Val)", verbose=verbose
    ):
        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model, batch=batch, task=task, compute_loss=True,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        # .mean() collapses any per-replica / per-element losses to one scalar.
        batch_loss = model_output.loss.mean().item()
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        nb_eval_steps += 1

    # An empty dataloader would otherwise raise ZeroDivisionError; report NaN instead.
    eval_loss = total_eval_loss / nb_eval_steps if nb_eval_steps else float("nan")
    # The tokenizer lives on the wrapped module when the model is data-parallel.
    tokenizer = (
        jiant_model.tokenizer
        if not torch_utils.is_data_parallel(jiant_model)
        else jiant_model.module.tokenizer
    )
    output = {
        "accumulator": eval_accumulator,
        "loss": eval_loss,
        "metrics": evaluation_scheme.compute_metrics_from_accumulator(
            task=task, accumulator=eval_accumulator, labels=val_labels, tokenizer=tokenizer,
        ),
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task, accumulator=eval_accumulator,
        )
    return output
Пример #3
0
def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    """Run inference over a task's test split, optionally returning encoder outputs.

    Only does work in the non-distributed setting (``local_rank == -1``); on
    any other rank it returns ``None`` immediately.

    Args:
        test_dataloader: iterable yielding ``(batch, batch_metadata)`` pairs.
        jiant_model: model to run; it is switched to eval mode.
        task: task whose evaluation scheme supplies the accumulator and the
            prediction extraction.
        device: device each batch is moved to before the forward pass.
        local_rank: distributed rank; ``-1`` means non-distributed.
        verbose: forwarded to ``maybe_tqdm`` to toggle the progress bar.
        return_preds: when True, include ``"preds"`` in the result.
        return_logits: only consulted when ``return_preds`` is True; adds
            ``"logits"`` if the accumulator is a ``ConcatenateLogitsAccumulator``.
        return_encoder_output: when True, ask ``wrap_jiant_forward`` for the
            encoder output too and add the concatenated ``pooled`` /
            ``unpooled`` parts to the result.

    Returns:
        ``None`` when ``local_rank != -1``; otherwise a dict that always holds
        ``"accumulator"`` plus whichever optional keys were requested.
    """
    if local_rank != -1:
        return
    jiant_model.eval()
    scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    accumulator = scheme.get_accumulator()

    collected_encoder_outputs = []
    batches = maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    for step, (batch, batch_metadata) in enumerate(batches):
        regular_log(logger, step, interval=10, tag='test')

        on_device = batch.to(device)

        with torch.no_grad():
            forward_result = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=on_device,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
        # With get_encoder_output=True the forward returns a (model_output, encoder_output) pair.
        if return_encoder_output:
            model_output, encoder_output = forward_result
            collected_encoder_outputs.append(encoder_output)
        else:
            model_output = forward_result

        accumulator.update(
            batch_logits=model_output.logits.detach().cpu().numpy(),
            batch_loss=0,
            batch=on_device,
            batch_metadata=batch_metadata,
        )

    result = {"accumulator": accumulator}
    if return_preds:
        result["preds"] = scheme.get_preds_from_accumulator(task=task, accumulator=accumulator)
        if isinstance(accumulator, evaluate.ConcatenateLogitsAccumulator) and return_logits:
            result["logits"] = accumulator.get_accumulated()
    if return_encoder_output:
        result["encoder_outputs_pooled"] = np.concatenate(
            [out.pooled for out in collected_encoder_outputs])
        result["encoder_outputs_unpooled"] = np.concatenate(
            [out.unpooled for out in collected_encoder_outputs])
    return result
Пример #4
0
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    """Evaluate ``jiant_model`` on one split of ``task`` and log results to TensorBoard.

    Only does work in the non-distributed setting (``local_rank == -1``); on
    any other rank it returns ``None`` immediately.

    Args:
        val_dataloader: iterable yielding ``(batch, batch_metadata)`` pairs.
        val_labels: labels passed to the evaluation scheme's
            ``compute_metrics_from_accumulator``.
        jiant_model: model to evaluate (possibly data-parallel wrapped); it is
            switched to eval mode.
        task: task whose evaluation scheme supplies accumulation and metrics.
        device: device each batch is moved to before the forward pass.
        local_rank: distributed rank; ``-1`` means non-distributed.
        tf_writer: TensorBoard writer for metric/loss scalars; flushed at the end.
        global_step: step attached to logged scalars; nothing is logged when None.
        phase: only used in the progress-bar description.
        return_preds: when True, include ``"preds"`` in the result.
        return_logits: only consulted when ``return_preds`` is True; adds
            ``"logits"`` if the accumulator is a ``ConcatenateLogitsAccumulator``.
        return_encoder_output: when True, ask ``wrap_jiant_forward`` for the
            encoder output too and add the concatenated ``pooled`` /
            ``unpooled`` parts to the result.
        verbose: forwarded to ``maybe_tqdm`` to toggle the progress bar.
        split: tag for ``regular_log`` and prefix for TensorBoard scalar names.

    Returns:
        ``None`` when ``local_rank != -1``; otherwise a dict with
        ``"accumulator"``; when labels are available also ``"loss"`` (mean over
        batches of each batch's mean loss, NaN if no batches) and
        ``"metrics"``; plus whichever optional keys were requested.
    """
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    has_labels = True  # TODO: automatically detect whether the dataset has labels.

    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    total_eval_loss = 0
    nb_eval_steps = 0

    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(val_dataloader,
                       desc=f"Eval ({task.name}, {str(phase)})",
                       verbose=verbose)):
        regular_log(logger, step, interval=10, tag=split)

        batch = batch.to(device)

        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
            # With get_encoder_output=True the forward returns a (model_output, encoder_output) pair.
            if return_encoder_output:
                model_output, encoder_output = model_outputs
                encoder_outputs.append(encoder_output)
            else:
                model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        # Without labels no loss is computed; 0 is a placeholder for the accumulator.
        batch_loss = model_output.loss.mean().item() if has_labels else 0
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )
        nb_eval_steps += 1

    # An empty dataloader would otherwise raise ZeroDivisionError; report NaN instead.
    eval_loss = total_eval_loss / nb_eval_steps if nb_eval_steps else float("nan")
    output = {
        "accumulator": eval_accumulator,
    }

    if has_labels:
        # The tokenizer lives on the wrapped module when the model is data-parallel.
        tokenizer = (jiant_model.tokenizer
                     if not torch_utils.is_data_parallel(jiant_model) else
                     jiant_model.module.tokenizer)
        metrics = evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        )

        output.update({
            "loss": eval_loss,
            "metrics": metrics,
        })

        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(f'{split}/{metric_name}',
                                     metric_value,
                                     global_step=global_step)

    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
        if isinstance(eval_accumulator,
                      evaluate.ConcatenateLogitsAccumulator) and return_logits:
            output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    # Only log the loss when it was actually computed; without labels it is a
    # placeholder 0 and is not part of the returned dict either.
    if has_labels and global_step is not None:
        tf_writer.add_scalar(f'{split}/loss',
                             eval_loss,
                             global_step=global_step)

    tf_writer.flush()
    return output