def on_epoch_start(self, runner: IRunner) -> None:
        """Prepare the set of loaders to run during the current epoch.

        A loader with period ``p`` is scheduled for epoch ``e`` when
        ``p > 0`` and ``e % p == 0`` (a period of ``0`` disables the
        loader entirely). When the configured validation loader is not
        scheduled this epoch, the first scheduled loader stands in as
        the validation loader, so metrics from the latest epoch with the
        true validation loader are reused where it is missing.

        Args:
            runner (IRunner): current runner

        Raises:
            ValueError: if there are no loaders in epoch
        """
        current_epoch = runner.epoch
        # collect only the loaders whose period matches this epoch
        scheduled = OrderedDict()
        for name, loader in self.loaders.items():
            period = self.loader_periods.get(name, 1)
            # a period of 0 means the loader is never scheduled
            if period > 0 and current_epoch % period == 0:
                scheduled[name] = loader
        if not scheduled:
            raise ValueError(f"There is no loaders in epoch {current_epoch}!")
        fallback = next(iter(scheduled))
        has_true_valid = self.valid_loader in scheduled
        runner.valid_loader = self.valid_loader if has_true_valid else fallback
        runner.loaders = scheduled
# Example 2
    def on_epoch_end(self, runner: IRunner):
        """Restore the original (unwrapped) loaders after the epoch.

        Any loader that was wrapped in a ``BatchLimitLoaderWrapper`` is
        replaced by the loader it wraps; plain loaders are kept as-is.

        Args:
            runner (IRunner): current runner
        """
        restored = {}
        for name, loader in runner.loaders.items():
            if isinstance(loader, BatchLimitLoaderWrapper):
                loader = loader.origin
            restored[name] = loader
        runner.loaders = restored
# Example 3
    def on_epoch_start(self, runner: IRunner) -> None:
        """Wrap each loader so only a limited number of batches is used.

        The per-loader limit comes from ``self.loader_batches``; a
        missing entry defaults to ``1``, i.e. overfitting on the first
        batch only. A float limit is interpreted as a fraction of the
        loader's length.

        Args:
            runner: current runner
        """
        limited = OrderedDict()

        for name, loader in runner.loaders.items():
            limit = self.loader_batches.get(name, 1)
            # fractional limits are resolved against the loader length
            if isinstance(limit, float):
                limit = int(len(loader) * limit)
            limited[name] = BatchLimitLoaderWrapper(
                loader=loader, num_batches=limit,
            )

        runner.loaders = limited