def test_scheduler_with_mixed_types(self):
    """Check a two-part composite under every interval_scaling combination.

    The expectations below imply that the mixed config splits training into
    two equal halves. A "rescaled" part replays its sub-scheduler over its
    whole range within its half, so it matches that sub-scheduler sampled at
    every other epoch. A "fixed" part is queried at the global progress, so
    it matches the sub-scheduler sampled at the epochs of its own half.
    """
    num_epochs = self._num_epochs
    half = num_epochs // 2
    every_other_epoch = range(0, num_epochs, 2)

    def sample(param_scheduler, epochs):
        # Query at each epoch's fraction of training; round to 4 decimal
        # places so floating-point noise cannot break the equality checks.
        return [round(param_scheduler(epoch / num_epochs), 4) for epoch in epochs]

    config = self._get_valid_mixed_config()
    first = build_param_scheduler(config["schedulers"][0])
    second = build_param_scheduler(config["schedulers"][1])

    # Both rescaled: each half replays its sub-scheduler end to end.
    config["interval_scaling"] = ["rescaled", "rescaled"]
    rescaled = sample(CompositeParamScheduler.from_config(config), range(num_epochs))
    self.assertEqual(
        rescaled,
        sample(first, every_other_epoch) + sample(second, every_other_epoch),
    )

    # Both fixed: first sub-scheduler on the first half, second on the
    # second half, each evaluated at the global progress value.
    config["interval_scaling"] = ["fixed", "fixed"]
    fixed = sample(CompositeParamScheduler.from_config(config), range(num_epochs))
    self.assertEqual(
        fixed,
        sample(first, range(half)) + sample(second, range(half, num_epochs)),
    )

    # Leaving interval_scaling out must reproduce the all-rescaled schedule.
    del config["interval_scaling"]
    default = sample(CompositeParamScheduler.from_config(config), range(num_epochs))
    self.assertEqual(rescaled, default)

    # Rescaled then fixed: the first half is stretched, the second half is
    # evaluated at the global progress value.
    config["interval_scaling"] = ["rescaled", "fixed"]
    mixed = sample(CompositeParamScheduler.from_config(config), range(num_epochs))
    self.assertEqual(
        mixed,
        sample(first, every_other_epoch) + sample(second, range(half, num_epochs)),
    )
def test_linear_scheduler_no_gaps(self):
    """Check that linear pieces compose into an unbroken ramp.

    Also check that a fixed composite built from two copies of the same
    linear config behaves exactly like that single linear scheduler.
    """
    progress = [epoch / self._num_epochs for epoch in range(self._num_epochs)]
    config = self._get_valid_linear_config()

    # Rescaled composition: values advance evenly with no jump where the
    # sub-schedulers meet.
    composite = CompositeParamScheduler.from_config(config)
    self.assertEqual(
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        [composite(where) for where in progress],
    )

    # Fixed composition of two identical linear configs must be
    # indistinguishable from one linear scheduler on its own.
    config["schedulers"][1] = config["schedulers"][0]
    config["interval_scaling"] = ["fixed", "fixed"]
    composite = CompositeParamScheduler.from_config(config)
    single = build_param_scheduler(config["schedulers"][0])
    actual = [composite(where) for where in progress]
    self.assertEqual([single(where) for where in progress], actual)
def test_build_composite_scheduler(self):
    """Check that a CompositeParamScheduler can be built two ways.

    The first way is the generic ``build_param_scheduler`` entry point. The
    second is calling the constructor directly with pre-built sub-schedulers.
    """
    config = self._get_valid_mixed_config()
    scheduler = build_param_scheduler(config)
    # assertIsInstance reports the actual type on failure; assertTrue only
    # reports "False is not true".
    self.assertIsInstance(scheduler, CompositeParamScheduler)

    schedulers = [
        build_param_scheduler(scheduler_config)
        for scheduler_config in config["schedulers"]
    ]
    composite = CompositeParamScheduler(
        schedulers=schedulers,
        lengths=config["lengths"],
        update_interval=UpdateInterval.EPOCH,
        interval_scaling=[IntervalScaling.RESCALED, IntervalScaling.FIXED],
    )
    self.assertIsInstance(composite, CompositeParamScheduler)
def test_scheduler_update_interval(self):
    """Check update_interval: STEP by default, and explicit values honoured.

    The explicit values tested are "step" and "epoch".
    """
    config = self._get_valid_mixed_config()

    # No update_interval in the config: from_config yields a per-step scheduler.
    self.assertEqual(
        CompositeParamScheduler.from_config(config).update_interval,
        UpdateInterval.STEP,
    )

    # Explicit values go through the generic builder, each on its own deep
    # copy so the cases cannot leak into one another.
    cases = (("step", UpdateInterval.STEP), ("epoch", UpdateInterval.EPOCH))
    for interval_name, expected in cases:
        variant = copy.deepcopy(config)
        variant["update_interval"] = interval_name
        self.assertEqual(build_param_scheduler(variant).update_interval, expected)
def from_config(cls, config: Dict[str, Any]) -> "Adam":
    """Instantiates an Adam optimizer from a configuration.

    Args:
        config: A configuration for an Adam optimizer. See :func:`__init__`
            for parameters expected in the config.
            - Required keys: "lr", "betas", "weight_decay" and "num_epochs".
            - Defaults: "eps" is 1e-8 and "amsgrad" is False.
            - The dict is updated in place: the defaults above are written
              into it, and a list-valued "betas" is replaced by a tuple.

    Returns:
        An Adam instance.

    Raises:
        AssertionError: if "lr" is missing, "betas" is not a pair of floats
            in [0, 1), or "weight_decay" is missing or invalid.
        KeyError: if "num_epochs" is missing.
    """
    # Default params
    config.setdefault("eps", 1e-8)
    config.setdefault("amsgrad", False)

    # A JSON config can only express lists, so accept a list for betas and
    # normalise it to the tuple the optimizer expects.
    if "betas" in config and isinstance(config["betas"], list):
        config["betas"] = tuple(config["betas"])

    assert (
        "lr" in config
    ), "Config must contain a learning rate 'lr' section for Adam optimizer"
    betas = config.get("betas")
    assert (
        isinstance(betas, tuple)
        and len(betas) == 2
        and all(isinstance(beta, float) and 0.0 <= beta < 1.0 for beta in betas)
    ), "Config must contain a tuple 'betas' in [0, 1) for Adam optimizer"
    assert "weight_decay" in config and is_pos_float(
        config["weight_decay"]
    ), "Config must contain a positive 'weight_decay' for Adam optimizer"

    lr_config = config["lr"]
    if not isinstance(lr_config, dict):
        # A bare value is shorthand for a constant learning-rate schedule.
        lr_config = {"name": "constant", "value": lr_config}
    # The scheduler needs the training length. When "lr" was already a dict,
    # this writes into config["lr"] in place.
    lr_config["num_epochs"] = config["num_epochs"]
    lr_scheduler = build_param_scheduler(lr_config)

    return cls(
        lr_scheduler=lr_scheduler,
        betas=config["betas"],
        eps=config["eps"],
        weight_decay=config["weight_decay"],
        amsgrad=config["amsgrad"],
    )
def from_config(cls, config: Dict[str, Any]) -> "RMSProp":
    """Instantiates an RMSProp optimizer from a configuration.

    Args:
        config: A configuration for an RMSProp optimizer. See
            :func:`__init__` for parameters expected in the config.
            - Required keys: "lr", "momentum", "alpha", "weight_decay" and
              "num_epochs".
            - Defaults: "eps" is 1e-8 and "centered" is False. These defaults
              are written into the dict in place.

    Returns:
        An RMSProp instance.

    Raises:
        AssertionError: if a required key is missing or has an invalid value.
        KeyError: if "num_epochs" is missing.
    """
    # Default params
    config.setdefault("eps", 1e-8)
    config.setdefault("centered", False)

    assert (
        "lr" in config
    ), "Config must contain a learning rate 'lr' section for RMSProp optimizer"
    for key in ["momentum", "alpha"]:
        # Check the type before comparing. A non-numeric value (e.g. a string
        # or None) would otherwise raise TypeError from the comparison
        # instead of this descriptive AssertionError.
        assert (
            key in config
            and isinstance(config[key], float)
            and 0.0 <= config[key] < 1.0
        ), f"Config must contain a '{key}' in [0, 1) for RMSProp optimizer"
    for key in ["weight_decay", "eps"]:
        assert key in config and is_pos_float(
            config[key]
        ), f"Config must contain a positive '{key}' for RMSProp optimizer"
    assert "centered" in config and isinstance(
        config["centered"], bool
    ), "Config must contain a boolean 'centered' param for RMSProp optimizer"

    lr_config = config["lr"]
    if not isinstance(lr_config, dict):
        # A bare value is shorthand for a constant learning-rate schedule.
        lr_config = {"name": "constant", "value": lr_config}
    # The scheduler needs the training length. When "lr" was already a dict,
    # this writes into config["lr"] in place.
    lr_config["num_epochs"] = config["num_epochs"]
    lr_scheduler = build_param_scheduler(lr_config)

    return cls(
        lr_scheduler=lr_scheduler,
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        alpha=config["alpha"],
        eps=config["eps"],
        centered=config["centered"],
    )
def from_config(cls, config: Dict[str, Any]) -> "SGD":
    """Instantiates an SGD optimizer from a configuration.

    Args:
        config: A configuration for an SGD optimizer. See :func:`__init__`
            for parameters expected in the config.
            - Required keys: "lr", "momentum", "weight_decay" and "num_epochs".
            - Defaults: "nesterov" is False. This default is written into the
              dict in place.

    Returns:
        An SGD instance.

    Raises:
        AssertionError: if a required key is missing or has an invalid value.
        KeyError: if "num_epochs" is missing.
    """
    # Default params
    config.setdefault("nesterov", False)

    assert (
        "lr" in config
    ), "Config must contain a learning rate 'lr' section for SGD optimizer"
    # Check the type before comparing. A non-numeric momentum (e.g. a string
    # or None) would otherwise raise TypeError from the comparison instead of
    # this descriptive AssertionError.
    assert (
        "momentum" in config
        and isinstance(config["momentum"], float)
        and 0.0 <= config["momentum"] < 1.0
    ), "Config must contain a 'momentum' in [0, 1) for SGD optimizer"
    assert "nesterov" in config and isinstance(
        config["nesterov"], bool
    ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
    assert "weight_decay" in config and is_pos_float(
        config["weight_decay"]
    ), "Config must contain a positive 'weight_decay' for SGD optimizer"

    lr_config = config["lr"]
    if not isinstance(lr_config, dict):
        # A bare value is shorthand for a constant learning-rate schedule.
        lr_config = {"name": "constant", "value": lr_config}
    # The scheduler needs the training length. When "lr" was already a dict,
    # this writes into config["lr"] in place.
    lr_config["num_epochs"] = config["num_epochs"]
    lr_scheduler = build_param_scheduler(lr_config)

    return cls(
        lr_scheduler=lr_scheduler,
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=config["nesterov"],
    )
def test_build_polynomial_scheduler(self):
    """The valid config builds a PolynomialDecayParamScheduler."""
    config = self._get_valid_config()
    scheduler = build_param_scheduler(config)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(scheduler, PolynomialDecayParamScheduler)
def test_build_non_equi_step_scheduler(self):
    """The valid config builds a MultiStepParamScheduler."""
    config = self._get_valid_config()
    scheduler = build_param_scheduler(config)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(scheduler, MultiStepParamScheduler)
def test_build_linear_scheduler(self):
    """The valid config builds a LinearParamScheduler."""
    config = self._get_valid_config()
    scheduler = build_param_scheduler(config)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(scheduler, LinearParamScheduler)
def test_build_composite_scheduler(self):
    """The valid mixed config builds a CompositeParamScheduler."""
    config = self._get_valid_mixed_config()
    scheduler = build_param_scheduler(config)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(scheduler, CompositeParamScheduler)
def test_build_cosine_scheduler(self):
    """The valid decay config builds a CosineParamScheduler."""
    config = self._get_valid_decay_config()
    scheduler = build_param_scheduler(config)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(scheduler, CosineParamScheduler)
def test_build_step_with_fixed_gamma_scheduler(self):
    """The valid config builds a StepWithFixedGammaParamScheduler."""
    config = self._get_valid_config()
    scheduler = build_param_scheduler(config)
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(scheduler, StepWithFixedGammaParamScheduler)