Example #1
import pytest
from torch import optim

# Assumed import locations; these have moved between PyTorch Lightning releases.
from pytorch_lightning.core.optimizer import _configure_optimizers
from pytorch_lightning.demos.boring_classes import BoringModel


def test_multiple_optimizer_config_dicts_with_extra_keys_warns(tmpdir):
    """Test that a warning is raised when multiple optimizer configuration dicts have extra keys."""
    model = BoringModel()
    optimizer1 = optim.Adam(model.parameters(), lr=0.01)
    optimizer2 = optim.Adam(model.parameters(), lr=0.01)
    lr_scheduler_config_1 = {"scheduler": optim.lr_scheduler.StepLR(optimizer1, 1)}
    lr_scheduler_config_2 = {"scheduler": optim.lr_scheduler.StepLR(optimizer2, 1)}
    optim_conf = [
        {
            "optimizer": optimizer1,
            "lr_scheduler": lr_scheduler_config_1,
            "foo": 1,
            "bar": 2
        },
        {
            "optimizer": optimizer2,
            "lr_scheduler": lr_scheduler_config_2,
            "foo": 1,
            "bar": 2
        },
    ]
    with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"):
        _configure_optimizers(optim_conf)


def test_optimizer_config_dict_with_extra_keys_warns(tmpdir):
    """Test exception when optimizer configuration dict has extra keys."""
    model = BoringModel()
    optimizer = optim.Adam(model.parameters())
    optim_conf = {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": optim.lr_scheduler.StepLR(optimizer, 1)},
        "foo": 1,
        "bar": 2,
    }
    with pytest.warns(RuntimeWarning, match=r"Found unsupported keys in the optimizer configuration: \{.+\}"):
        _configure_optimizers(optim_conf)
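
For contrast, here is a minimal sketch (not taken from the Lightning test suite) of a configuration dict that sticks to the keys Lightning recognizes, "optimizer" and "lr_scheduler", so _configure_optimizers processes it without the RuntimeWarning; it reuses the assumed imports from the top of this example.

import warnings

def test_optimizer_config_dict_with_only_supported_keys_sketch():
    """Sketch only: no unsupported keys, so no RuntimeWarning is expected."""
    model = BoringModel()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    optim_conf = {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": optim.lr_scheduler.StepLR(optimizer, 1)},
    }
    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)  # turn any RuntimeWarning into a test failure
        optimizers, lr_schedulers, frequencies, monitor = _configure_optimizers(optim_conf)
    assert optimizer in optimizers  # the single optimizer comes back in the normalized list
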
Example #3
    def configure_optimizers(self):
        """
        Combine architecture optimizers and user's model optimizers.
        You can overwrite :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.

        For now :attr:`model` is tested against evaluators in :mod:`nni.retiarii.evaluator.pytorch.lightning`
        and it only returns 1 optimizer.
        But for extendibility, codes for other return value types are also implemented.
        """
        # pylint: disable=assignment-from-none
        arc_optimizers = self.configure_architecture_optimizers()
        if arc_optimizers is None:
            return self.model.configure_optimizers()

        if isinstance(arc_optimizers, optim.Optimizer):
            arc_optimizers = [arc_optimizers]
        self.arc_optim_count = len(arc_optimizers)

        # FIXME: this part uses non-official lightning API.
        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
        # ``len(optimizers) == len(frequency)``, and the backward pass is handled manually.
        # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
        try:
            # PyTorch Lightning >= 1.6
            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
                _configure_optimizers,  # type: ignore
                _configure_schedulers_automatic_opt,  # type: ignore
                _configure_schedulers_manual_opt  # type: ignore
            )
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
            if self.automatic_optimization:
                lr_schedulers = _configure_schedulers_automatic_opt(lr_schedulers, monitor)
            else:
                lr_schedulers = _configure_schedulers_manual_opt(lr_schedulers)
        except ImportError:
            # PyTorch Lightning <= 1.5
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
            lr_schedulers = self.trainer._configure_schedulers(
                lr_schedulers, monitor,
                not self.automatic_optimization)  # type: ignore

        if any(sch["scheduler"].optimizer not in w_optimizers
               for sch in lr_schedulers):  # type: ignore
            raise Exception(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )

        # variables used to handle optimizer frequency
        self.cur_optimizer_step = 0
        self.cur_optimizer_index = 0

        return arc_optimizers + w_optimizers, lr_schedulers
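
The hook referenced in the docstring above is what a concrete one-shot algorithm fills in. Below is a minimal, hypothetical sketch of such an override; the class name OneShotModuleSketch, the assumed base class BaseOneShotLightningModule, and the helper self.nas_parameters() are illustrative, not NNI API, and only show the shape that the combined configure_optimizers above expects.

import torch.optim as optim

class OneShotModuleSketch(BaseOneShotLightningModule):  # base-class name assumed; not shown in this excerpt
    def configure_architecture_optimizers(self):
        # Returning a single optimizer is enough: ``configure_optimizers`` above
        # wraps it into a list, records ``self.arc_optim_count``, and prepends it
        # to the user model's weight optimizers.
        arc_params = self.nas_parameters()  # hypothetical helper collecting architecture parameters
        return optim.Adam(arc_params, lr=3e-4, betas=(0.5, 0.999))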