def get_metrics(model: Model, total_loss: float, num_batches: int, reset: bool = False) -> Dict[str, float]:
    """
    Return the model's metrics with ``"loss"`` replaced by the mean loss per batch.

    The base dictionary comes from ``model.get_metrics(reset=reset)``; it is
    updated in place and then returned.  Its ``"loss"`` entry becomes
    ``float(total_loss / num_batches)``, or ``0.0`` when ``num_batches`` is not
    positive (so an empty run never divides by zero).
    """
    result = model.get_metrics(reset=reset)
    if num_batches > 0:
        result["loss"] = float(total_loss / num_batches)
    else:
        result["loss"] = 0.0
    return result
def evaluate(model: Model,
             instances: Iterable[Instance],
             data_iterator: DataIterator,
             cuda_device: int,
             batch_weight_key: str) -> Dict[str, Any]:
    """
    Run ``model`` once over ``instances`` without gradients and return its metrics.

    Batches are produced by ``data_iterator`` (a single epoch, no shuffling),
    moved to ``cuda_device`` and fed to the model in eval mode.  A running
    summary of the metrics is written to a tqdm progress bar.

    Parameters
    ----------
    model : ``Model``
        The model to evaluate.  ``model(**batch)`` must return a dict that may
        contain a ``"loss"`` entry (read with ``.item()``).
    instances : ``Iterable[Instance]``
        The data to evaluate on.
    data_iterator : ``DataIterator``
        Used to batch ``instances`` and to count the batches for the progress bar.
    cuda_device : ``int``
        Device each batch is moved to; validated with ``check_for_gpu`` first.
    batch_weight_key : ``str``
        If non-empty, the key in the model's output dict whose value (read with
        ``.item()``) weights that batch's loss.  If empty, every batch has
        weight ``1.0``.

    Returns
    -------
    ``Dict[str, Any]``
        ``model.get_metrics(reset=True)``; when the model produced a loss, a
        ``"loss"`` entry holding the weighted mean loss is added.  If the
        cumulative weight is zero the reported loss is ``0.0`` instead of
        raising ``ZeroDivisionError``.

    Raises
    ------
    RuntimeError
        If the model produced a loss for some batches but not for others.
    """
    check_for_gpu(cuda_device)
    with torch.no_grad():
        # Note: this function never switches the model back with ``model.train()``.
        model.eval()

        iterator = data_iterator(instances, num_epochs=1, shuffle=False)
        logger.info("Iterating over dataset")
        generator_tqdm = Tqdm.tqdm(iterator, total=data_iterator.get_num_batches(instances))

        # Number of batches in instances.
        batch_count = 0
        # Number of batches where the model produces a loss.
        loss_count = 0
        # Cumulative weighted loss.
        total_loss = 0.0
        # Cumulative weight across all batches that produced a loss.
        total_weight = 0.0

        for batch in generator_tqdm:
            batch_count += 1
            batch = nn_util.move_to_device(batch, cuda_device)
            output_dict = model(**batch)
            loss = output_dict.get("loss")

            metrics = model.get_metrics()

            if loss is not None:
                loss_count += 1
                if batch_weight_key:
                    # Presumably a scalar tensor (e.g. a token or instance count);
                    # confirm against the model's output dict.
                    weight = output_dict[batch_weight_key].item()
                else:
                    weight = 1.0

                total_weight += weight
                total_loss += loss.item() * weight
                # Report the average loss so far.  A zero cumulative weight (every
                # batch so far weighted 0) would otherwise raise ZeroDivisionError.
                metrics["loss"] = total_loss / total_weight if total_weight != 0.0 else 0.0

            # Warn once that underscore-prefixed metrics are hidden from the bar.
            if (not HasBeenWarned.tqdm_ignores_underscores
                    and any(metric_name.startswith("_") for metric_name in metrics)):
                logger.warning("Metrics with names beginning with \"_\" will "
                               "not be logged to the tqdm progress bar.")
                HasBeenWarned.tqdm_ignores_underscores = True
            description = ', '.join(["%s: %.2f" % (name, value)
                                     for name, value in metrics.items()
                                     if not name.startswith("_")]) + " ||"
            generator_tqdm.set_description(description, refresh=False)

        final_metrics = model.get_metrics(reset=True)
        if loss_count > 0:
            # Sanity check: refuse to report a loss that covers only some batches.
            if loss_count != batch_count:
                raise RuntimeError("The model you are trying to evaluate only sometimes " +
                                   "produced a loss!")
            final_metrics["loss"] = total_loss / total_weight if total_weight != 0.0 else 0.0

        return final_metrics