Ejemplo n.º 1
0
    def test_sched_config_parse_from_cls(self):
        """Build a CosineAnnealing scheduler from a ``_target_``-style config.

        The same configuration is handed to ``prepare_lr_scheduler`` twice:
        once as a plain ``dict`` and once after conversion through
        ``OmegaConf.create``. Both calls must return a mapping whose
        ``"scheduler"`` entry is a ``CosineAnnealing`` instance.
        """
        network = TempModel()
        optimizer = get_optimizer("novograd")(
            network.parameters(), lr=self.INITIAL_LR)

        sched_cfg = {
            "_target_": "mridc.core.conf.schedulers.CosineAnnealingParams",
            "params": {"min_lr": 0.1},
            "max_steps": self.MAX_STEPS,
        }

        # Variant 1: plain dictionary.
        from_dict = optim.lr_scheduler.prepare_lr_scheduler(
            optimizer, sched_cfg)
        built = from_dict["scheduler"]
        if not isinstance(built, optim.lr_scheduler.CosineAnnealing):
            raise AssertionError

        # Variant 2: OmegaConf container, created only after the dict variant
        # has been processed so the call order stays the same.
        omega_cfg = omegaconf.OmegaConf.create(sched_cfg)
        from_omega = optim.lr_scheduler.prepare_lr_scheduler(
            optimizer, omega_cfg)
        built = from_omega["scheduler"]
        if not isinstance(built, optim.lr_scheduler.CosineAnnealing):
            raise AssertionError
Ejemplo n.º 2
0
    def test_SquareAnnealing(self):
        """Check SquareAnnealing learning-rate bounds with and without warmup.

        In both scenarios the scheduler is stepped ``MAX_STEPS`` times
        alongside the optimizer, then once more on its own, and the final
        learning rate must equal ``MIN_LR``.
        """
        network = TempModel()
        optimizer = get_optimizer("novograd")(
            network.parameters(), lr=self.INITIAL_LR)

        # --- Scenario 1: no warmup ---
        scheduler = optim.lr_scheduler.SquareAnnealing(
            optimizer, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr != self.INITIAL_LR:
            raise AssertionError

        for _ in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            # The rate may never exceed the optimizer's starting rate.
            if current_lr > self.INITIAL_LR:
                raise AssertionError
            optimizer.step()
            scheduler.step()

        # One extra scheduler step before reading the final rate.
        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        if end_lr != self.MIN_LR:
            raise AssertionError

        # --- Scenario 2: five warmup steps ---
        scheduler = optim.lr_scheduler.SquareAnnealing(
            optimizer,
            warmup_steps=5,
            max_steps=self.MAX_STEPS,
            min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        # With warmup the first rate must be strictly below the base rate.
        if start_lr >= self.INITIAL_LR:
            raise AssertionError

        for step_idx in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            if step_idx <= 5:
                # Indices 0-5: reaching the base rate is allowed.
                out_of_bounds = current_lr > self.INITIAL_LR
            else:
                # Later indices: the rate must stay strictly below it.
                out_of_bounds = current_lr >= self.INITIAL_LR
            if out_of_bounds:
                raise AssertionError

            optimizer.step()
            scheduler.step()

        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        if end_lr != self.MIN_LR:
            raise AssertionError
Ejemplo n.º 3
0
    def test_register_optimizer(self):
        """Register a custom optimizer and fetch it back by name."""

        class TempOpt(torch.optim.SGD):
            """SGD subclass used only as a registration target."""

        class TempOptParams(optimizers.SGDParams):
            """Parameter class paired with ``TempOpt`` at registration."""

        register_optimizer("TempOpt", TempOpt, TempOptParams)

        network = TempModel()
        optimizer_cls = get_optimizer("TempOpt")
        instance = optimizer_cls(network.parameters(), lr=self.INITIAL_LR)

        # The registry lookup must hand back the class registered above.
        if isinstance(instance, TempOpt):
            return
        raise AssertionError
Ejemplo n.º 4
0
    def test_get_optimizer(self):
        """Instantiate every entry of ``AVAILABLE_OPTIMIZERS`` and check its type.

        ``fused_adam`` is skipped when CUDA is unavailable, and ``adafactor``
        is constructed without an explicit learning rate.
        """
        network = TempModel()

        for name in AVAILABLE_OPTIMIZERS:
            needs_cuda = name == "fused_adam"
            if needs_cuda and not torch.cuda.is_available():
                continue

            factory = get_optimizer(name)
            params = network.parameters()
            # Adafactor's default mode uses relative_step without any lr.
            extra_kwargs = {} if name == "adafactor" else {
                "lr": self.INITIAL_LR}
            instance = factory(params, **extra_kwargs)

            expected_cls = AVAILABLE_OPTIMIZERS[name]
            if not isinstance(instance, expected_cls):
                raise AssertionError
Ejemplo n.º 5
0
    def test_sched_config_parse_simple(self):
        """Build a CosineAnnealing scheduler from a name-based config.

        The config is passed to ``prepare_lr_scheduler`` first as a plain
        ``dict`` and then after conversion through ``OmegaConf.create``. Each
        result's ``"scheduler"`` entry must be a ``CosineAnnealing`` instance.
        """
        network = TempModel()
        optimizer = get_optimizer("novograd")(
            network.parameters(), lr=self.INITIAL_LR)

        sched_cfg = {"name": "CosineAnnealing", "max_steps": 10}

        # Variant 1: plain dictionary.
        from_dict = optim.lr_scheduler.prepare_lr_scheduler(
            optimizer, sched_cfg)
        built = from_dict["scheduler"]
        if not isinstance(built, optim.lr_scheduler.CosineAnnealing):
            raise AssertionError

        # Variant 2: OmegaConf container, created after the dict variant ran.
        omega_cfg = omegaconf.OmegaConf.create(sched_cfg)
        from_omega = optim.lr_scheduler.prepare_lr_scheduler(
            optimizer, omega_cfg)
        built = from_omega["scheduler"]
        if not isinstance(built, optim.lr_scheduler.CosineAnnealing):
            raise AssertionError
Ejemplo n.º 6
0
    def test_register_scheduler(self):
        """Register a custom scheduler and build it via the registry."""

        class TempSched(optim.lr_scheduler.CosineAnnealing):
            """CosineAnnealing subclass used only as a registration target."""

        class TempSchedParams(CosineAnnealingParams):
            """Parameter class paired with ``TempSched`` at registration."""

        optim.lr_scheduler.register_scheduler(
            "TempSched", TempSched, TempSchedParams)

        network = TempModel()
        optimizer = get_optimizer("novograd")(
            network.parameters(), lr=self.INITIAL_LR)
        scheduler_cls = optim.lr_scheduler.get_scheduler("TempSched")
        instance = scheduler_cls(optimizer, max_steps=self.MAX_STEPS)

        # The registry lookup must hand back the class registered above.
        if isinstance(instance, TempSched):
            return
        raise AssertionError
Ejemplo n.º 7
0
    def test_CosineAnnealing_with_noop_steps(self):
        """Check CosineAnnealing when some scheduler steps are undone.

        On every even loop index ``last_epoch`` is decremented right after
        ``step()``, so only odd indices count as real updates. The final
        learning rate must stay above ``MIN_LR`` and match ``_get_lr``
        evaluated at the number of counted updates.
        """
        network = TempModel()
        optimizer = get_optimizer("novograd")(
            network.parameters(), lr=self.INITIAL_LR)

        # No warmup is configured for this scheduler.
        scheduler = optim.lr_scheduler.CosineAnnealing(
            optimizer, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr != self.INITIAL_LR:
            raise AssertionError

        counted_updates = 0
        for step_idx in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            if current_lr > self.INITIAL_LR:
                raise AssertionError
            optimizer.step()
            scheduler.step()

            # Perform a no-op for the scheduler every 2 steps: odd indices
            # count as updates, even indices roll the epoch counter back.
            if step_idx % 2:
                counted_updates += 1
            else:
                scheduler.last_epoch -= 1

        # One additional, counted scheduler step after the loop.
        scheduler.step()
        counted_updates += 1

        # Skipped steps mean fewer updates than MAX_STEPS were performed.
        if counted_updates >= self.MAX_STEPS:
            raise AssertionError

        end_lr = scheduler.get_last_lr()[0]
        if end_lr <= self.MIN_LR:
            raise AssertionError

        # counted_updates = true number of updates performed after the
        # skipped steps; the schedule evaluated there must match end_lr.
        expected_end_lr = scheduler._get_lr(step=counted_updates)[0]
        if end_lr != expected_end_lr:
            raise AssertionError
Ejemplo n.º 8
0
    def test_PolynomialHoldDecayAnnealing(self):
        """Check PolynomialHoldDecayAnnealing bounds in three configurations.

        Configurations covered, all with ``power=2``: no warmup, warmup only,
        and warmup followed by hold steps. Each one steps the scheduler
        ``MAX_STEPS`` times with the optimizer and once more afterwards.
        """
        network = TempModel()
        optimizer = get_optimizer("novograd")(
            network.parameters(), lr=self.INITIAL_LR)
        scheduler_cls = optim.lr_scheduler.PolynomialHoldDecayAnnealing

        # --- Scenario 1: no warmup ---
        scheduler = scheduler_cls(
            optimizer,
            power=2,
            max_steps=self.MAX_STEPS,
            min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr != self.INITIAL_LR:
            raise AssertionError

        for _ in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            if current_lr > self.INITIAL_LR:
                raise AssertionError
            optimizer.step()
            scheduler.step()

        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        # Here the final rate must remain strictly above MIN_LR.
        if end_lr <= self.MIN_LR:
            raise AssertionError

        # --- Scenario 2: five warmup steps ---
        scheduler = scheduler_cls(
            optimizer,
            power=2,
            warmup_steps=5,
            max_steps=self.MAX_STEPS,
            min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr >= self.INITIAL_LR:
            raise AssertionError

        for _ in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            if current_lr > self.INITIAL_LR:
                raise AssertionError

            optimizer.step()
            scheduler.step()

        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        if end_lr < self.MIN_LR:
            raise AssertionError

        # --- Scenario 3: five warmup steps plus three hold steps ---
        scheduler = scheduler_cls(
            optimizer,
            warmup_steps=5,
            hold_steps=3,
            max_steps=self.MAX_STEPS,
            min_lr=self.MIN_LR,
            power=2)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr >= self.INITIAL_LR:
            raise AssertionError

        for step_idx in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            # Indices 5-8 form the window where the rate may not drop below
            # the base rate; everywhere else it may not exceed it.
            in_hold_window = 4 < step_idx <= 8
            if in_hold_window:
                out_of_bounds = current_lr < self.INITIAL_LR
            else:
                out_of_bounds = current_lr > self.INITIAL_LR
            if out_of_bounds:
                raise AssertionError
            optimizer.step()
            scheduler.step()

        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        if end_lr < self.MIN_LR:
            raise AssertionError
Ejemplo n.º 9
0
    def test_WarmupAnnealing(self):
        """Check learning-rate bounds for WarmupAnnealing and WarmupHoldPolicy.

        Three scenarios run in sequence on the same optimizer:
        ``WarmupAnnealing`` without warmup, ``WarmupAnnealing`` with five
        warmup steps, and ``WarmupHoldPolicy`` with five warmup and three
        hold steps.
        """
        network = TempModel()
        optimizer = get_optimizer("novograd")(
            network.parameters(), lr=self.INITIAL_LR)

        # --- Scenario 1: WarmupAnnealing, no warmup ---
        scheduler = optim.lr_scheduler.WarmupAnnealing(
            optimizer, max_steps=self.MAX_STEPS, min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr != self.INITIAL_LR:
            raise AssertionError

        for _ in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            if current_lr > self.INITIAL_LR:
                raise AssertionError
            optimizer.step()
            scheduler.step()

        # One extra scheduler step before reading the final rate.
        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        if end_lr < self.MIN_LR:
            raise AssertionError

        # --- Scenario 2: WarmupAnnealing, five warmup steps ---
        scheduler = optim.lr_scheduler.WarmupAnnealing(
            optimizer,
            warmup_steps=5,
            max_steps=self.MAX_STEPS,
            min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr >= self.INITIAL_LR:
            raise AssertionError

        for step_idx in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            if step_idx <= 5:
                # Indices 0-5: reaching the base rate is allowed.
                out_of_bounds = current_lr > self.INITIAL_LR
            else:
                # Later indices: the rate must stay strictly below it.
                out_of_bounds = current_lr >= self.INITIAL_LR
            if out_of_bounds:
                raise AssertionError

            optimizer.step()
            scheduler.step()

        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        if end_lr != self.MIN_LR:
            raise AssertionError

        # --- Scenario 3: WarmupHoldPolicy, five warmup + three hold steps ---
        scheduler = optim.lr_scheduler.WarmupHoldPolicy(
            optimizer,
            warmup_steps=5,
            hold_steps=3,
            max_steps=self.MAX_STEPS,
            min_lr=self.MIN_LR)
        start_lr = scheduler.get_last_lr()[0]
        if start_lr >= self.INITIAL_LR:
            raise AssertionError

        for step_idx in range(self.MAX_STEPS):
            current_lr = scheduler.get_last_lr()[0]
            if step_idx <= 4:
                # Indices 0-4: the rate may not exceed the base rate.
                out_of_bounds = current_lr > self.INITIAL_LR
            else:
                # From index 5 on the rate must equal the base rate exactly.
                out_of_bounds = current_lr != self.INITIAL_LR
            if out_of_bounds:
                raise AssertionError
            optimizer.step()
            scheduler.step()

        scheduler.step()
        end_lr = scheduler.get_last_lr()[0]
        if end_lr < self.MIN_LR:
            raise AssertionError