def test_linear_scheduler_no_gaps(self):
    """Composed linear sub-schedulers should form one gap-free ramp.

    First, the linear helper config (the rescaled case) must produce
    0.0 .. 0.9 in steps of 0.1 across ``self._num_epochs`` samples. Second,
    placing the first linear scheduler in both slots with ``fixed`` scaling
    must give exactly the same values as that scheduler evaluated on its own.
    """
    config = self._get_valid_linear_config()
    total = self._num_epochs
    positions = [epoch / total for epoch in range(total)]

    # Rescaled composition straight from the helper config.
    composite = CompositeParamScheduler.from_config(config)
    self.assertEqual(
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        [composite(where) for where in positions],
    )

    # Same scheduler in both slots, fixed scaling: the composite should be
    # indistinguishable from that single scheduler.
    config["schedulers"][1] = config["schedulers"][0]
    config["interval_scaling"] = ["fixed", "fixed"]
    composite = CompositeParamScheduler.from_config(config)
    reference = build_param_scheduler(config["schedulers"][0])
    actual = [composite(where) for where in positions]
    self.assertEqual([reference(where) for where in positions], actual)
def test_scheduler_with_mixed_types(self):
    """Exercise each ``interval_scaling`` combination on the mixed config.

    The composite is sampled at ``epoch / self._num_epochs`` for every epoch.
    The samples are rounded to 4 decimal places and compared against the two
    sub-schedulers built directly from ``config["schedulers"]``. The expected
    values encode the scaling semantics:

    * ``rescaled``: a sub-scheduler's input is stretched over its slot. With
      the half-length slots these expectations imply, it is sampled at every
      other epoch.
    * ``fixed``: a sub-scheduler sees the global position unchanged.

    Removing ``interval_scaling`` must reproduce the all-rescaled schedule.
    """
    config = self._get_valid_mixed_config()
    first = build_param_scheduler(config["schedulers"][0])
    second = build_param_scheduler(config["schedulers"][1])

    total = self._num_epochs
    half = int(total / 2)
    every_epoch = range(total)
    every_other_epoch = range(0, total, 2)

    def sample(param_scheduler, epochs):
        # Evaluate at epoch / total and round to 4 decimal places.
        return [round(param_scheduler(epoch / total), 4) for epoch in epochs]

    # Both slots rescaled.
    config["interval_scaling"] = ["rescaled", "rescaled"]
    rescaled = sample(CompositeParamScheduler.from_config(config), every_epoch)
    self.assertEqual(
        rescaled,
        sample(first, every_other_epoch) + sample(second, every_other_epoch),
    )

    # Both slots fixed: first half comes from `first`, second half from `second`.
    config["interval_scaling"] = ["fixed", "fixed"]
    fixed = sample(CompositeParamScheduler.from_config(config), every_epoch)
    self.assertEqual(
        fixed,
        sample(first, range(half)) + sample(second, range(half, total)),
    )

    # No interval_scaling key at all: must match the all-rescaled schedule.
    del config["interval_scaling"]
    self.assertEqual(
        rescaled, sample(CompositeParamScheduler.from_config(config), every_epoch)
    )

    # Rescaled warm-up phase followed by a fixed second phase.
    config["interval_scaling"] = ["rescaled", "fixed"]
    warmup_then_fixed = sample(
        CompositeParamScheduler.from_config(config), every_epoch
    )
    self.assertEqual(
        warmup_then_fixed,
        sample(first, every_other_epoch) + sample(second, range(half, total)),
    )
def test_scheduler_lengths_within_epsilon_of_one(self):
    """A ``lengths`` list summing to just under 1 is still accepted.

    The composite built from ``_get_lengths_sum_less_one_config`` must build
    without error. It must yield 0.1 for the first seven samples and 0.2 for
    the last three.
    """
    scheduler = CompositeParamScheduler.from_config(
        self._get_lengths_sum_less_one_config()
    )
    total = self._num_epochs
    observed = [scheduler(epoch / total) for epoch in range(total)]
    self.assertEqual(observed, [0.1] * 7 + [0.2] * 3)
def test_long_scheduler(self):
    """A composite with many phases yields the expected piecewise schedule.

    Sampling ``_get_valid_long_config`` at each ``epoch / self._num_epochs``
    must give the following values:

    * 0.1 twice
    * 0.2 four times
    * 0.3 once
    * 0.4 three times
    """
    scheduler = CompositeParamScheduler.from_config(self._get_valid_long_config())
    total = self._num_epochs
    expected = [0.1] * 2 + [0.2] * 4 + [0.3] + [0.4] * 3
    self.assertEqual([scheduler(epoch / total) for epoch in range(total)], expected)
def test_scheduler_update_interval(self):
    """``update_interval`` defaults to STEP and honours explicit values.

    The unmodified mixed config goes through
    ``CompositeParamScheduler.from_config``. The helper config presumably
    carries no ``update_interval``, so this checks the default. Explicit
    ``"step"`` / ``"epoch"`` values are set on deep copies and built through
    the generic ``build_param_scheduler`` factory.
    """
    config = self._get_valid_mixed_config()

    self.assertEqual(
        CompositeParamScheduler.from_config(config).update_interval,
        UpdateInterval.STEP,
    )

    for raw_value, expected in (
        ("step", UpdateInterval.STEP),
        ("epoch", UpdateInterval.EPOCH),
    ):
        variant = copy.deepcopy(config)
        variant["update_interval"] = raw_value
        self.assertEqual(build_param_scheduler(variant).update_interval, expected)
def test_invalid_config(self):
    """Malformed composite configs must be rejected by ``from_config``.

    Each case starts from its own deep copy of the valid mixed config and
    breaks exactly one thing. This means no case can pass because of state
    left behind by an earlier one. Most violations are expected to raise
    ``AssertionError``. An unknown ``update_interval`` only has to raise
    some ``Exception``.
    """
    config = self._get_valid_mixed_config()

    def fresh():
        # Isolated copy so mutations never leak between cases.
        return copy.deepcopy(config)

    def check_rejected(bad_config, case, error=AssertionError):
        with self.assertRaises(error, msg=case):
            CompositeParamScheduler.from_config(bad_config)

    bad_config = fresh()
    bad_config["schedulers"] = []
    bad_config["lengths"] = []
    check_rejected(bad_config, "no schedulers")

    bad_config = fresh()
    bad_config["schedulers"].append(bad_config["schedulers"][-1])
    check_rejected(bad_config, "schedulers/lengths size mismatch")

    bad_config = fresh()
    bad_config["lengths"][-1] -= 0.1
    check_rejected(bad_config, "lengths sum < 1")

    bad_config = fresh()
    bad_config["lengths"][-1] += 0.1
    check_rejected(bad_config, "lengths sum > 1")

    # Only a generic Exception is required for an unknown update_interval.
    bad_config = fresh()
    bad_config["update_interval"] = "epochs"
    check_rejected(bad_config, "bad update_interval", Exception)

    bad_config = fresh()
    bad_config["interval_scaling"] = ["rescaled", "rescaleds"]
    check_rejected(bad_config, "bad interval_scaling value")

    bad_config = fresh()
    bad_config["interval_scaling"] = ["rescaled"]
    check_rejected(bad_config, "wrong number of interval_scaling entries")

    for required_key in ("lengths", "schedulers"):
        bad_config = fresh()
        del bad_config[required_key]
        check_rejected(bad_config, f"missing {required_key!r}")