# Example #1
def get_metrics(model: Model,
                total_loss: float,
                num_batches: int,
                reset: bool = False) -> Dict[str, float]:
    """
    Gets the metrics but sets ``"loss"`` to
    the total loss divided by the ``num_batches`` so that
    the ``"loss"`` metric is "average loss per batch".

    Parameters
    ----------
    model : Model
        Object whose ``get_metrics(reset=...)`` supplies the base metrics dict.
    total_loss : float
        Loss summed over ``num_batches`` batches.
    num_batches : int
        Number of batches ``total_loss`` was accumulated over. When it is not
        positive, ``"loss"`` is reported as ``0.0`` instead of dividing by zero.
    reset : bool, optional (default = False)
        Forwarded unchanged to ``model.get_metrics``.

    Returns
    -------
    Dict[str, float]
        The dict returned by ``model.get_metrics`` with its ``"loss"`` entry
        added or overwritten by the per-batch average loss.
    """
    metrics = model.get_metrics(reset=reset)
    # Guard against an empty epoch (no batches) rather than raising ZeroDivisionError.
    metrics["loss"] = float(total_loss / num_batches) if num_batches > 0 else 0.0
    return metrics
# Example #2
def evaluate(model: Model, instances: Iterable[Instance],
             data_iterator: DataIterator, cuda_device: int,
             batch_weight_key: str) -> Dict[str, Any]:
    """
    Run ``model`` over ``instances`` once, without gradients, and return its metrics.

    A tqdm progress bar is shown whose description is refreshed after every
    batch with the model's running metrics (and the running average loss).

    Parameters
    ----------
    model : Model
        The model to evaluate. It is put into eval mode (``model.eval()``)
        before any batch is processed.
    instances : Iterable[Instance]
        The data to evaluate on.
    data_iterator : DataIterator
        Called with ``instances`` to produce batches for a single epoch,
        without shuffling.
    cuda_device : int
        Device each batch is moved to via ``nn_util.move_to_device``.
    batch_weight_key : str
        If non-empty, the key in the model's output dict whose value
        (presumably a scalar tensor, read with ``.item()``) weights that
        batch's loss. If empty, every batch gets weight ``1.0``.

    Returns
    -------
    Dict[str, Any]
        ``model.get_metrics(reset=True)``, plus ``"loss"`` set to the weighted
        average loss when the model produced a loss.

    Raises
    ------
    RuntimeError
        If the model produced a loss for some batches but not for others.
    """
    with torch.no_grad():
        # no_grad only disables autograd; eval() is what switches off
        # training-only behaviour such as dropout.
        model.eval()

        iterator = data_iterator(instances, num_epochs=1, shuffle=False)
        logger.info("Iterating over dataset")
        generator_tqdm = Tqdm.tqdm(
            iterator, total=data_iterator.get_num_batches(instances))

        # Number of batches in instances.
        batch_count = 0
        # Number of batches where the model produces a loss.
        loss_count = 0
        # Cumulative weighted loss
        total_loss = 0.0
        # Cumulative weight across all batches.
        total_weight = 0.0

        for batch in generator_tqdm:
            batch_count += 1
            batch = nn_util.move_to_device(batch, cuda_device)
            output_dict = model(**batch)
            loss = output_dict.get("loss")

            metrics = model.get_metrics()

            if loss is not None:
                loss_count += 1
                # Previously the fallback ran unconditionally and discarded the
                # per-batch weight; it now applies only when no key is given.
                if batch_weight_key:
                    weight = output_dict[batch_weight_key].item()
                else:
                    weight = 1.0

                total_weight += weight
                total_loss += loss.item() * weight
                # Report the average loss so far.
                metrics["loss"] = total_loss / total_weight

            # Warn once (process-wide flag) that "_"-prefixed metrics are hidden.
            if (not HasBeenWarned.tqdm_ignores_underscores and any(
                    metric_name.startswith("_") for metric_name in metrics)):
                logger.warning("Metrics with names beginning with \"_\" will "
                               "not be logged to the tqdm progress bar.")
                HasBeenWarned.tqdm_ignores_underscores = True
            description = ', '.join([
                "%s: %.2f" % (name, value)
                for name, value in metrics.items() if not name.startswith("_")
            ]) + " ||"
            generator_tqdm.set_description(description, refresh=False)

        final_metrics = model.get_metrics(reset=True)
        if loss_count > 0:
            # Sanity check: a loss averaged over only some batches would be misleading.
            if loss_count != batch_count:
                raise RuntimeError(
                    "The model you are trying to evaluate only sometimes " +
                    "produced a loss!")
            final_metrics["loss"] = total_loss / total_weight

        return final_metrics