def test_invalid_config(self):
        # The test expects CosineParamScheduler.from_config to raise an
        # AssertionError or a TypeError when either "start_value" or
        # "end_value" is absent from an otherwise valid config.
        config = self._get_valid_decay_config()

        # Remove one key at a time from a fresh deep copy, so every case is
        # missing exactly one key and the valid config is never mutated.
        for missing_key in ("start_value", "end_value"):
            bad_config = copy.deepcopy(config)
            del bad_config[missing_key]
            with self.assertRaises((AssertionError, TypeError)):
                CosineParamScheduler.from_config(bad_config)
    def test_scheduler_warmup_decay_match(self):
        # Checks that a scheduler built with start and end values exchanged
        # (warmup) yields the decay schedule in reverse order.
        decay_config = self._get_valid_decay_config()
        decay_scheduler = CosineParamScheduler.from_config(decay_config)

        # Derive the warmup config from an independent copy of the decay one.
        warmup_config = copy.deepcopy(decay_config)
        warmup_config["start_value"], warmup_config["end_value"] = (
            warmup_config["end_value"],
            warmup_config["start_value"],
        )
        warmup_scheduler = CosineParamScheduler.from_config(warmup_config)

        # Sample points 1/1000, 2/1000, ..., 999/1000; both end points are
        # deliberately left out, exactly as range(1, 1000) dictates.
        sample_points = [step / 1000 for step in range(1, 1000)]

        # Values are rounded to 8 decimals before the comparison.
        decay_schedule = [
            round(decay_scheduler(where), 8) for where in sample_points
        ]
        warmup_schedule = [
            round(warmup_scheduler(where), 8) for where in sample_points
        ]

        self.assertEqual(decay_schedule, warmup_schedule[::-1])
    def test_scheduler_as_decay(self):
        # Evaluates the scheduler once per epoch and compares the result,
        # rounded to 4 decimals, with the expected decay values.
        config = self._get_valid_decay_config()
        scheduler = CosineParamScheduler.from_config(config)

        total_epochs = self._num_epochs
        schedule = []
        for epoch in range(total_epochs):
            # The scheduler is called with the fraction epoch / total_epochs.
            schedule.append(round(scheduler(epoch / total_epochs), 4))

        # Expected: the configured start value followed by the intermediate
        # values supplied by the fixture helper.
        first_value = config["start_value"]
        intermediate_values = self._get_valid_decay_config_intermediate_values()
        expected_schedule = [first_value] + intermediate_values

        self.assertEqual(schedule, expected_schedule)
    def test_scheduler_as_warmup(self):
        # Builds a warmup scheduler by exchanging the start and end values of
        # the decay config, then checks its per-epoch values.
        config = self._get_valid_decay_config()
        config["start_value"], config["end_value"] = (
            config["end_value"],
            config["start_value"],
        )

        scheduler = CosineParamScheduler.from_config(config)

        total_epochs = self._num_epochs
        schedule = []
        for epoch in range(total_epochs):
            # The scheduler is called with the fraction epoch / total_epochs.
            schedule.append(round(scheduler(epoch / total_epochs), 4))

        # Expected: the (swapped) start value followed by the decay
        # intermediate values in reverse order.
        reversed_intermediates = list(
            reversed(self._get_valid_decay_config_intermediate_values())
        )
        expected_schedule = [config["start_value"]] + reversed_intermediates

        self.assertEqual(schedule, expected_schedule)