Example #1
    def __call__(self, manager):
        """Execute the statistics extension.

        Collect statistics for the current state of parameters.

        Note that this method merely updates its statistics summary unless
        the internal trigger fires. When the trigger fires, the summary is
        also reported and then reset for the next accumulation.

        Args:
            manager (~pytorch_pfn_extras.training.ExtensionsManager):
                Associated manager that invoked this extension.
        """
        statistics = {}

        for link in self._links:
            for param_name, param in link.named_parameters():
                for attr_name in self._attrs:
                    # Get the attribute as a flattened one-dimensional
                    # array since the statistics functions should make no
                    # assumptions about the axes.
                    params = getattr(param, attr_name).flatten()
                    skip_param = (self._skip_nan_params
                                  and bool(torch.isnan(params).any()))
                    for function_name, function in self._statistics.items():
                        if skip_param:
                            value = float('nan')
                        else:
                            value = function(params)
                        key = self.report_key_template.format(
                            prefix=self._prefix + '/' if self._prefix else '',
                            param_name=param_name,
                            attr_name=attr_name,
                            function_name=function_name
                        )
                        if (isinstance(value, torch.Tensor)
                                and value.numel() > 1):
                            # Append integer indices to the keys if the
                            # statistics function returns multiple values.
                            statistics.update({'{}/{}'.format(key, i): v for
                                               i, v in enumerate(value)})
                        else:
                            statistics[key] = value

        self._summary.add(statistics)

        if self._trigger(manager):
            reporting.report(self._summary.compute_mean())
            self._summary = reporting.DictSummary()  # Clear summary
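
A minimal sketch of how a statistics extension like this is typically
registered; the model, the trigger, and the keyword arguments below are
illustrative assumptions, not part of the snippet above.

    import torch.nn as nn
    import pytorch_pfn_extras as ppe
    from pytorch_pfn_extras.training import extensions

    model = nn.Linear(10, 2)  # hypothetical model
    manager = ppe.training.ExtensionsManager(
        {'main': model}, {}, max_epochs=5, iters_per_epoch=100)
    # Report parameter/gradient statistics once per epoch under 'stats/'.
    manager.extend(extensions.ParameterStatistics(model, prefix='stats'),
                   trigger=(1, 'epoch'))
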
Example #2
    def evaluate(self):
        """Evaluates the model and returns a result dictionary.

        This method runs the evaluation loop over the validation dataset. It
        accumulates the reported values into a
        :class:`~pytorch_pfn_extras.reporting.DictSummary` and returns a
        dictionary whose values are the means computed by the summary.

        Users can override this method to customize the evaluation routine.

        Returns:
            dict: Result dictionary. This dictionary is further reported via
            :func:`~pytorch_pfn_extras.reporting.report` without specifying
            any observer.

        """
        iterator = self._iterators['main']

        if self.eval_hook:
            self.eval_hook(self)

        summary = reporting.DictSummary()

        progress = IterationStatus(len(iterator))
        if self._progress_bar:
            pbar = _IteratorProgressBar(iterator=progress)

        last_iter = len(iterator) - 1
        with _in_eval_mode(self._targets.values()):
            for idx, batch in enumerate(iterator):
                last_batch = idx == last_iter
                progress.current_position = idx
                observation = {}
                with reporting.report_scope(observation):
                    if isinstance(batch, (tuple, list)):
                        outs = self.eval_func(*batch)
                    elif isinstance(batch, dict):
                        outs = self.eval_func(**batch)
                    else:
                        outs = self.eval_func(batch)
                    for metric in self._metrics:
                        metric(batch, outs, last_batch)
                summary.add(observation)

                if self._progress_bar:
                    pbar.update()

        if self._progress_bar:
            pbar.close()

        return summary.compute_mean()
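
The loop above is normally driven by registering the evaluator as an
extension. A hedged sketch, assuming a validation loader named
``val_loader`` and the ``manager`` and ``model`` from the previous sketch:

    # Usage sketch; val_loader is assumed to exist.
    evaluator = extensions.Evaluator(
        val_loader, model, eval_func=model, progress_bar=True)
    manager.extend(evaluator, trigger=(1, 'epoch'))
    # When triggered, the manager invokes the evaluator, which calls
    # evaluate() and reports the per-key means it returns.
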
Example #3
    def run(self,
            loader: Iterable[Any],
            *,
            eval_len: Optional[int] = None) -> None:
        """Executes the evaluation loop.

        Args:
            loader (torch.utils.data.DataLoader):
                A data loader for evaluation.
            eval_len (int, optional):
                The number of iterations in one evaluation epoch. Defaults
                to ``len(loader)`` when omitted.
        """
        # Note: setup_manager is done by the Trainer.
        self._idxs: 'queue.Queue[int]' = queue.Queue()
        self._inputs: 'queue.Queue[DictBatch]' = queue.Queue()
        self._observed: 'queue.Queue[Observation]' = queue.Queue()

        if eval_len is None:
            eval_len = len(loader)  # type: ignore[arg-type]
        self._eval_len = eval_len

        self._summary = reporting.DictSummary()
        observation: Observation = {}
        self.handler.eval_loop_begin(self)
        self._pbar = _progress_bar('validation', self._progress_bar, eval_len)
        self._update = self._pbar.__enter__()
        loader_iter = iter(loader)
        with self._profile or _nullcontext() as prof:
            with torch.no_grad():  # type: ignore[no-untyped-call]
                for idx in range(eval_len):
                    try:
                        x = next(loader_iter)
                    except StopIteration:
                        break
                    self._idxs.put(idx)
                    self._inputs.put(x)
                    self._observed.put(observation)
                    with self._reporter.scope(observation):
                        self.handler.eval_step(self, idx, x,
                                               self._complete_step)
                    # Some DataLoaders need an explicit break here because
                    # they may cycle over their data indefinitely.
                    if (idx + 1) == eval_len:
                        break
                    if prof is not None:
                        prof.step()  # type: ignore[no-untyped-call]
        # This will report to the trainer main reporter
        self.handler.eval_loop_end(self)
        reporting.report(self._summary.compute_mean())
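
This run method belongs to an engine-style evaluator that is normally
created and driven through the Trainer (note the setup_manager comment
above). A sketch of that wiring, assuming the create_trainer and
create_evaluator factories and illustrative models and loaders:

    # Wiring sketch; model, optimizer and the loaders are assumptions.
    import pytorch_pfn_extras as ppe

    evaluator = ppe.engine.create_evaluator(model, device='cpu')
    trainer = ppe.engine.create_trainer(
        model, optimizer, max_epochs=5, evaluator=evaluator, device='cpu')
    # The trainer drives the evaluation loop (run above) on val_loader.
    trainer.run(train_loader, val_loader)
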
    def _init_summary(self):
        self._summary = reporting.DictSummary()
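
All three snippets accumulate observations through reporting.DictSummary.
A minimal sketch of its add/compute_mean contract; the keys and numbers
are made up for illustration:

    from pytorch_pfn_extras import reporting

    summary = reporting.DictSummary()
    summary.add({'loss': 0.5, 'acc': 0.8})
    summary.add({'loss': 0.3, 'acc': 0.9})
    # compute_mean() averages every value added under each key.
    means = summary.compute_mean()  # expected: {'loss': 0.4, 'acc': 0.85}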