Ejemplo n.º 1
0
 def on_stage_start(self, runner: IRunner):
     """Look up the stage criterion on the runner and cache it.

     Args:
         runner: current runner, queried through ``get_attr`` with
             ``key="criterion"`` and ``inner_key=self.criterion_key``

     Raises:
         AssertionError: if the runner returns ``None`` for the criterion
     """
     lookup = runner.get_attr
     stage_criterion = lookup(key="criterion", inner_key=self.criterion_key)
     # Validate before storing: nothing is cached when the lookup fails.
     assert stage_criterion is not None
     self._criterion = stage_criterion
Ejemplo n.º 2
0
    def on_stage_start(self, runner: IRunner) -> None:
        """Bind the stage scheduler and resolve the stepping mode.

        Falls back to ``runner.main_metric`` when ``reduced_metric`` is
        falsy, stores the scheduler returned by the runner and, when
        ``mode`` is ``None``, picks ``"batch"`` for ``BatchScheduler``
        instances and ``"epoch"`` otherwise.

        Args:
            runner: current runner

        Raises:
            AssertionError: if the runner returns ``None`` for the scheduler
        """
        metric = self.reduced_metric
        self.reduced_metric = metric if metric else runner.main_metric

        stage_scheduler = runner.get_attr(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert stage_scheduler is not None
        self._scheduler = stage_scheduler

        # An explicitly configured mode always wins over the inferred one.
        if self.mode is None:
            batch_wise = isinstance(stage_scheduler, BatchScheduler)
            self.mode = "batch" if batch_wise else "epoch"

        # One-cycle schedulers are reset only when stepped per batch.
        one_cycle = isinstance(stage_scheduler, OneCycleLRWithWarmup)
        if one_cycle and self.mode == "batch":
            stage_scheduler.reset()
        assert self.mode is not None
Ejemplo n.º 3
0
    def on_stage_start(self, runner: IRunner) -> None:
        """Fetch the stage optimizer from the runner and validate it.

        Args:
            runner(IRunner): current runner

        Raises:
            AssertionError: if the runner returns ``None`` for the optimizer
        """
        stage_optimizer = runner.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        # Stored before the check, so the attribute is set even on failure.
        self._optimizer = stage_optimizer
        assert stage_optimizer is not None
Ejemplo n.º 4
0
    def on_stage_start(self, runner: IRunner) -> None:
        """Cache the stage optimizer and its default learning rate.

        Args:
            runner (IRunner): current runner

        Raises:
            AssertionError: if the runner returns ``None`` for the optimizer
        """
        fetch = runner.get_attr
        stage_optimizer = fetch(key="optimizer", inner_key=self.optimizer_key)
        assert stage_optimizer is not None
        self._optimizer = stage_optimizer
        # The initial learning rate is read from the optimizer defaults.
        self.init_lr = stage_optimizer.defaults["lr"]
Ejemplo n.º 5
0
    def on_stage_start(self, runner: IRunner) -> None:
        """Fetch the stage optimizer and create a fresh gradient scaler.

        Args:
            runner(IRunner): current runner

        Raises:
            AssertionError: if the runner returns ``None`` for the optimizer
        """
        # Imported lazily inside the hook rather than at module level.
        from torch.cuda import amp

        stage_optimizer = runner.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        self._optimizer = stage_optimizer
        self.scaler = amp.GradScaler()
        # Checked last: optimizer and scaler are both set even on failure.
        assert stage_optimizer is not None
Ejemplo n.º 6
0
    def on_stage_start(self, runner: IRunner) -> None:
        """Fetch the stage optimizer and select the optimizer step function.

        Args:
            runner(IRunner): current runner

        Raises:
            AssertionError: if the runner returns ``None`` for the optimizer
        """
        stage_optimizer = runner.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        self._optimizer = stage_optimizer

        # The step implementation depends on the runner's device type.
        on_xla = runner.device.type == "xla"
        self._optimizer_step_fn = (
            self._optimizer_step_tpu if on_xla else self._optimizer_step
        )

        assert stage_optimizer is not None