def __call__(self, manager):
    """Execute the statistics extension.

    Collect statistics for the current state of parameters.

    Note that this method will merely update its statistic summary, unless
    the internal trigger is fired. If the trigger is fired, the summary
    will also be reported and then reset for the next accumulation.

    Parameter attributes whose value is ``None`` (for example ``grad`` of a
    parameter that has not received a gradient) are skipped, instead of
    failing with ``AttributeError`` on ``None.flatten()``.

    Args:
        manager (~pytorch_pfn_extras.training.ExtensionsManager):
            Associated manager that invoked this extension.
    """
    # The prefix does not depend on any loop variable; build it once.
    prefix = self._prefix + '/' if self._prefix else ''
    statistics = {}
    for link in self._links:
        for param_name, param in link.named_parameters():
            for attr_name in self._attrs:
                attr = getattr(param, attr_name)
                if attr is None:
                    # Nothing to summarize, e.g. ``param.grad`` before the
                    # first backward pass or for a frozen parameter.
                    continue
                # Get parameters as a flattened one-dimensional array
                # since the statistics function should make no
                # assumption about the axes. This (and the NaN check) only
                # depends on the attribute, so it is computed once here
                # rather than once per statistic function.
                params = attr.flatten()
                has_nan = (self._skip_nan_params
                           and bool(torch.isnan(params).any()))
                for function_name, function in self._statistics.items():
                    value = float('nan') if has_nan else function(params)
                    key = self.report_key_template.format(
                        prefix=prefix,
                        param_name=param_name,
                        attr_name=attr_name,
                        function_name=function_name,
                    )
                    if isinstance(value, torch.Tensor) and value.numel() > 1:
                        # Append integer indices to the keys if the
                        # statistic function returns multiple values
                        statistics.update(
                            {'{}/{}'.format(key, i): v
                             for i, v in enumerate(value)})
                    else:
                        statistics[key] = value
    self._summary.add(statistics)
    if self._trigger(manager):
        reporting.report(self._summary.compute_mean())
        self._summary = reporting.DictSummary()  # Clear summary
def evaluate(self):
    """Evaluates the model and returns a result dictionary.

    This method runs the evaluation loop over the validation dataset. It
    accumulates the reported values to :class:`~DictSummary` and returns a
    dictionary whose values are means computed by the summary.

    Users can override this method to customize the evaluation routine.

    If a progress bar was created, it is closed even when ``eval_func`` or a
    metric raises.

    Returns:
        dict: Result dictionary. This dictionary is further reported via
        :func:`~pytorch_pfn_extras.report` without specifying any observer.
    """
    iterator = self._iterators['main']

    if self.eval_hook:
        self.eval_hook(self)

    summary = reporting.DictSummary()

    n_iters = len(iterator)
    progress = IterationStatus(n_iters)
    pbar = None
    if self._progress_bar:
        pbar = _IteratorProgressBar(iterator=progress)

    # Index of the final batch; metrics are told when they see it.
    last_iter = n_iters - 1
    try:
        with _in_eval_mode(self._targets.values()):
            for idx, batch in enumerate(iterator):
                last_batch = idx == last_iter
                progress.current_position = idx
                observation = {}
                with reporting.report_scope(observation):
                    # Unpack the batch according to its container type.
                    if isinstance(batch, (tuple, list)):
                        outs = self.eval_func(*batch)
                    elif isinstance(batch, dict):
                        outs = self.eval_func(**batch)
                    else:
                        outs = self.eval_func(batch)
                    for metric in self._metrics:
                        metric(batch, outs, last_batch)
                summary.add(observation)

                if pbar is not None:
                    pbar.update()
    finally:
        # Run the bar's cleanup on both the normal and the error path.
        if pbar is not None:
            pbar.close()

    return summary.compute_mean()
def run(
    self,
    loader: Iterable[Any],
    *,
    eval_len: Optional[int] = None,
) -> None:
    """Executes the evaluation loop.

    Every step reports into its own observation dictionary. That dictionary
    is queued together with the step index and input, so steps that
    complete later do not see values reported by other steps.

    Args:
        loader (torch.utils.data.DataLoader): A data loader for evaluation.
        eval_len (int, optional): The number of iterations per one
            evaluation epoch. Defaults to ``len(loader)``.
    """
    # Note: setup_manager is done by the Trainer.
    self._idxs: 'queue.Queue[int]' = queue.Queue()
    self._inputs: 'queue.Queue[DictBatch]' = queue.Queue()
    self._observed: 'queue.Queue[Observation]' = queue.Queue()

    if eval_len is None:
        eval_len = len(loader)  # type: ignore[arg-type]
    self._eval_len = eval_len

    self._summary = reporting.DictSummary()
    self.handler.eval_loop_begin(self)
    self._pbar = _progress_bar('validation', self._progress_bar, eval_len)
    self._update = self._pbar.__enter__()
    loader_iter = iter(loader)
    with self._profile or _nullcontext() as prof:
        with torch.no_grad():  # type: ignore[no-untyped-call]
            for idx in range(eval_len):
                try:
                    x = next(loader_iter)
                except StopIteration:
                    break
                # A fresh dict per step. The same object is queued in
                # ``_observed`` (presumably consumed when the step
                # completes) and used as the reporter scope. A single
                # shared dict would let keys from earlier steps linger,
                # and would let a later step overwrite values of a step
                # that has not been consumed yet.
                observation: Observation = {}
                self._idxs.put(idx)
                self._inputs.put(x)
                self._observed.put(observation)
                with self._reporter.scope(observation):
                    self.handler.eval_step(self, idx, x, self._complete_step)
                # Some of the DataLoaders might need an explicit break
                # since they could start cycling on their data.
                # NOTE(review): ``range(eval_len)`` already bounds the loop,
                # so this break's only effect is skipping ``prof.step()``
                # on the final iteration.
                if (idx + 1) == eval_len:
                    break
                if prof is not None:
                    prof.step()  # type: ignore[no-untyped-call]
    # This will report to the trainer main reporter
    self.handler.eval_loop_end(self)
    reporting.report(self._summary.compute_mean())
def _init_summary(self):
    """Replace ``self._summary`` with a newly constructed ``DictSummary``."""
    fresh_summary = reporting.DictSummary()
    self._summary = fresh_summary