Ejemplo n.º 1
0
    def test_scheduler(self):
        """Check that LinearParamScheduler interpolates in both directions.

        The valid config is sampled once as-is (ramping from ``start_value``
        toward ``end_value``). Its two endpoints are then swapped, and the
        same intermediate values are expected to appear in reverse order.
        """
        config = self._get_valid_config()

        def sample(sched):
            # Query the scheduler at each epoch's fractional position
            # epoch / num_epochs; round to 4 places so the comparison
            # against the expected literals is not thrown off by float noise.
            return [
                round(sched(epoch / self._num_epochs), 4)
                for epoch in range(self._num_epochs)
            ]

        # Warmup direction: values move from start_value toward end_value.
        warmup = LinearParamScheduler.from_config(config)
        self.assertEqual(
            sample(warmup),
            [config["start_value"]] + self._get_valid_intermediate(),
        )

        # Decay direction: swap the endpoints in place and expect the
        # intermediate values to be visited back-to-front.
        config["start_value"], config["end_value"] = (
            config["end_value"],
            config["start_value"],
        )
        decay = LinearParamScheduler.from_config(config)
        self.assertEqual(
            sample(decay),
            [config["start_value"]]
            + list(reversed(self._get_valid_intermediate())),
        )
Ejemplo n.º 2
0
    def test_invalid_config(self):
        """Check that from_config rejects a config missing an endpoint.

        Starting from a deep copy of a valid config, first "start_value" is
        removed. It is then put back and "end_value" is removed instead. Both
        constructions must raise AssertionError or TypeError.
        """
        valid = self._get_valid_config()
        accepted_errors = (AssertionError, TypeError)
        bad_config = copy.deepcopy(valid)

        # Case 1: "start_value" is absent.
        del bad_config["start_value"]
        with self.assertRaises(accepted_errors):
            LinearParamScheduler.from_config(bad_config)

        # Case 2: restore "start_value" from the valid config, drop "end_value".
        bad_config["start_value"] = valid["start_value"]
        del bad_config["end_value"]
        with self.assertRaises(accepted_errors):
            LinearParamScheduler.from_config(bad_config)