Code example #1
    def on_stage_start(self, runner: IRunner):
        """Checks that the current stage has correct criterion."""
        criterion = runner.get_attr(
            key="criterion", inner_key=self.criterion_key
        )
        assert criterion is not None
        self._criterion = criterion
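This hook resolves the stage's criterion from the runner (the optional `criterion_key` selects a single criterion when the stage registers several), asserts that the stage actually configured one, and caches it on the callback for later loss computation.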
Code example #2
    def on_stage_start(self, runner: IRunner) -> None:
        """Stage start hook.

        Args:
            runner: current runner
        """
        self.reduced_metric = self.reduced_metric or runner.main_metric

        scheduler = runner.get_attr(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert scheduler is not None
        self._scheduler = scheduler

        if self.mode is None:
            if isinstance(scheduler, BatchScheduler):
                self.mode = "batch"
            else:
                self.mode = "epoch"

        if (
            isinstance(scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            scheduler.reset()
        assert self.mode is not None
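Beyond resolving the scheduler with the same `get_attr`/`assert` pattern, this variant falls back to the runner's `main_metric` when no metric was given, infers the stepping mode ("batch" for `BatchScheduler` subclasses, "epoch" otherwise), and calls `reset()` on `OneCycleLRWithWarmup` schedulers in batch mode, presumably so the warmup phase restarts with the new stage.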
Code example #3
    def on_stage_start(self, runner: IRunner) -> None:
        """Checks that the current stage has correct optimizer.

        Args:
            runner (IRunner): current runner
        """
        self._optimizer = runner.get_attr(key="optimizer",
                                          inner_key=self.optimizer_key)
        assert self._optimizer is not None
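The optimizer counterpart of example #1: the same lookup-and-assert pattern, only for the `optimizer` attribute and an optional `optimizer_key`.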
Code example #4
File: scheduler.py (project: valayDave/catalyst)
    def on_stage_start(self, runner: IRunner) -> None:
        """Stage start hook.

        Args:
            runner (IRunner): current runner
        """
        optimizer = runner.get_attr(key="optimizer",
                                    inner_key=self.optimizer_key)
        assert optimizer is not None
        self._optimizer = optimizer
        self.init_lr = optimizer.defaults["lr"]
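In addition to resolving the optimizer, this hook records the base learning rate from `optimizer.defaults["lr"]`, which suggests a learning-rate updater or finder that needs the initial value as a reference for later adjustments.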
Code example #5
    def on_stage_start(self, runner: IRunner) -> None:
        """Checks that the current stage has correct optimizer.

        Args:
            runner (IRunner): current runner
        """
        from torch.cuda.amp import GradScaler

        self._optimizer = runner.get_attr(key="optimizer",
                                          inner_key=self.optimizer_key)
        self.scaler = GradScaler()
        assert self._optimizer is not None
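This variant additionally constructs a `torch.cuda.amp.GradScaler`, preparing the stage for mixed-precision training. The import happens lazily inside the hook, presumably so the module stays importable on PyTorch builds that predate the AMP API.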
Code example #6
    def on_stage_start(self, runner: IRunner) -> None:
        """Checks that the current stage has correct optimizer.

        Args:
            runner (IRunner): current runner
        """
        self._optimizer = runner.get_attr(key="optimizer",
                                          inner_key=self.optimizer_key)
        # device based optimization step
        if runner.device.type == "xla":
            self._optimizer_step_fn = self._optimizer_step_tpu
        else:
            self._optimizer_step_fn = self._optimizer_step

        assert self._optimizer is not None
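The last variant dispatches the optimizer step by device: on XLA (TPU) devices it swaps in a TPU-specific step function, otherwise the regular one, so the rest of the callback can stay device-agnostic.

All six hooks share one pattern: resolve a stage-scoped attribute via `runner.get_attr`, fail fast with an `assert`, and cache the result on the callback instance. Below is a minimal sketch, not taken from any of the projects above, of how such a hook plugs into a custom Catalyst callback. It assumes the Catalyst 20.x API, where `catalyst.core` exports `Callback`, `CallbackOrder`, and `IRunner`; the class name `CriterionCheckCallback` is hypothetical.

from catalyst.core import Callback, CallbackOrder, IRunner


class CriterionCheckCallback(Callback):
    """Hypothetical callback mirroring the pattern in the examples above."""

    def __init__(self, criterion_key: str = None):
        # CallbackOrder controls execution order relative to other callbacks.
        super().__init__(order=CallbackOrder.Metric)
        self.criterion_key = criterion_key
        self._criterion = None

    def on_stage_start(self, runner: IRunner) -> None:
        # Same pattern as the examples: pull the stage-scoped attribute
        # off the runner and fail fast if the stage did not configure it.
        criterion = runner.get_attr(
            key="criterion", inner_key=self.criterion_key
        )
        assert criterion is not None
        self._criterion = criterion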