def validate(self, val_iterator, info, metrics):
    """Runs one standard validation pass over the val_iterator.

    This will call ``model.eval()`` and ``torch.no_grad`` when iterating
    over the validation dataloader.

    If overriding this method, you can access model, criterion via
    ``self.model`` and ``self.criterion``. You also do not need to call
    ``validate_batch`` if overriding this method.

    Args:
        val_iterator (iter): Iterable constructed from the validation
            dataloader.
        info (dict): Dictionary for information to be used for custom
            validation operations.
        metrics: Metrics (or metric specs) to evaluate; converted to
            pytorch-backend metric objects internally.

    Returns:
        A dict of metrics from the evaluation, including "val_loss"
        (the per-sample weighted average of batch losses) and
        "num_samples" (total samples seen). Each entry of ``metrics``
        contributes one key computed via its ``compute()`` method.
    """
    # switch to evaluate mode (disables dropout, freezes batchnorm stats)
    self.model.eval()
    metrics = Metric.convert_metrics_dict(metrics, backend="pytorch")
    losses = []
    total_samples = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_iterator):
            batch_info = {"batch_idx": batch_idx}
            batch_info.update(info)
            output, target, loss = self.forward_batch(batch, batch_info)
            num_samples = target.size(0)
            total_samples += num_samples
            # Weight each batch loss by its sample count so the final
            # average is per-sample, not per-batch (last batch may be short).
            losses.append(loss.item() * num_samples)
            for metric in metrics.values():
                metric(output, target)

    result = {name: metric.compute() for name, metric in metrics.items()}
    # Guard against an empty validation iterator: the original code would
    # raise ZeroDivisionError here. Report 0.0 loss over 0 samples instead;
    # callers can detect the empty pass via result["num_samples"] == 0.
    result["val_loss"] = sum(losses) / total_samples if total_samples > 0 else 0.0
    result["num_samples"] = total_samples
    return result
def validate(config, model, valid_ld, metrics, validate_batches):
    """Evaluate ``model`` on ``validate_batches`` batches drawn from ``valid_ld``.

    Puts the model in eval mode, runs the requested number of batches under
    ``torch.no_grad()``, feeds every (input, target) pair to each metric,
    prints a summary line, and returns the computed metric values as a dict.
    If ``validate_batches`` exceeds the loader length, iteration wraps
    around to the start of the loader.
    """
    import torch
    from zoo.orca.learn.metrics import Metric

    model.eval()
    metrics = Metric.convert_metrics_dict(metrics, backend="pytorch")
    valid_iter = iter(valid_ld)
    with torch.no_grad():
        for step in range(validate_batches):
            # Re-create the iterator whenever a full pass over the
            # loader has been consumed, so we can keep drawing batches.
            if step > 0 and step % len(valid_ld) == 0:
                valid_iter = iter(valid_ld)
            features, labels = next(valid_iter)
            # NOTE(review): the target is passed into the forward call
            # alongside the input — presumably this model's forward
            # expects both; confirm against the model definition.
            predictions = model(features, labels)
            for metric in metrics.values():
                metric(predictions, labels)

    result = {name: metric.compute() for name, metric in metrics.items()}
    summary = "Validation results: " + "".join(
        "{}:{} ".format(name, value) for name, value in result.items()
    )
    print(summary)
    return result