def test_invalid_config(self):
    """Check that ``from_config`` rejects a config missing a required key.

    Starting from a valid decay config, each of ``start_value`` and
    ``end_value`` is removed in turn. Building the scheduler from the
    resulting config must raise ``AssertionError`` or ``TypeError``.
    """
    config = self._get_valid_decay_config()

    for missing_key in ("start_value", "end_value"):
        with self.subTest(missing_key=missing_key):
            # Each case starts from a pristine copy so a mutation made for
            # one case cannot mask or cause a failure in another.
            bad_config = copy.deepcopy(config)
            del bad_config[missing_key]
            with self.assertRaises((AssertionError, TypeError)):
                CosineParamScheduler.from_config(bad_config)
def test_scheduler_warmup_decay_match(self):
    """Check that a warmup schedule is the decay schedule read backwards.

    Swapping ``start_value`` and ``end_value`` of the valid decay config
    yields a warmup config. Both schedulers are sampled at the interior
    points ``step / num_points`` for ``step`` in ``1 .. num_points - 1``.
    The warmup samples, reversed, must equal the decay samples.
    """
    num_points = 1000

    decay_config = self._get_valid_decay_config()
    decay_scheduler = CosineParamScheduler.from_config(decay_config)

    warmup_config = copy.deepcopy(decay_config)
    # Swap start and end values to change the decay into a warmup.
    warmup_config["start_value"], warmup_config["end_value"] = (
        warmup_config["end_value"],
        warmup_config["start_value"],
    )
    warmup_scheduler = CosineParamScheduler.from_config(warmup_config)

    # Endpoints 0 and 1 are excluded. Rounding to 8 places presumably guards
    # against tiny float differences between the two evaluations.
    sample_points = [step / num_points for step in range(1, num_points)]
    decay_schedule = [round(decay_scheduler(t), 8) for t in sample_points]
    warmup_schedule = [round(warmup_scheduler(t), 8) for t in sample_points]

    self.assertEqual(decay_schedule, list(reversed(warmup_schedule)))
def test_scheduler_as_decay(self):
    """Check the per-epoch decay schedule against precomputed values.

    The scheduler is evaluated at ``epoch / self._num_epochs`` for every
    epoch in ``range(self._num_epochs)`` and rounded to 4 places. The first
    value must equal ``start_value``. The remaining values must equal
    ``self._get_valid_decay_config_intermediate_values()``.
    """
    config = self._get_valid_decay_config()
    scheduler = CosineParamScheduler.from_config(config)

    num_epochs = self._num_epochs
    observed = [
        round(scheduler(epoch / num_epochs), 4) for epoch in range(num_epochs)
    ]

    first_value = [config["start_value"]]
    expected = first_value + self._get_valid_decay_config_intermediate_values()

    self.assertEqual(observed, expected)
def test_scheduler_as_warmup(self):
    """Check that the warmup schedule is the decay schedule reversed.

    The valid decay config has its ``start_value`` and ``end_value``
    swapped, turning it into a warmup. The scheduler is sampled at
    ``epoch / self._num_epochs`` for every epoch and rounded to 4 places.
    The schedule must start at the (swapped) ``start_value`` and then visit
    the decay's intermediate values in reverse order.
    """
    config = self._get_valid_decay_config()
    # Swap start and end values to change the decay into a warmup.
    config["start_value"], config["end_value"] = (
        config["end_value"],
        config["start_value"],
    )
    scheduler = CosineParamScheduler.from_config(config)

    num_epochs = self._num_epochs
    schedule = [
        round(scheduler(epoch / num_epochs), 4) for epoch in range(num_epochs)
    ]

    # The schedule should be the decay schedule reversed.
    expected_schedule = [config["start_value"]] + list(
        reversed(self._get_valid_decay_config_intermediate_values())
    )
    self.assertEqual(schedule, expected_schedule)