def from_config(cls, config: Dict[str, Any]) -> "Adam":
    """Instantiates a Adam from a configuration.

    Args:
        config: A configuration for a Adam. See :func:`__init__` for
            parameters expected in the config. Required keys are ``lr``,
            ``betas``, ``weight_decay`` and ``num_epochs``; ``eps``
            (default ``1e-8``) and ``amsgrad`` (default ``False``) are
            optional. A non-dict ``lr`` is wrapped into a
            ``{"name": "constant", ...}`` scheduler config before being
            handed to ``build_param_scheduler``. The dict is updated in
            place: missing defaults are filled in, a list ``betas`` is
            replaced by a tuple, and when ``lr`` is a dict ``num_epochs`` is
            written into it.

    Returns:
        A Adam instance.

    Raises:
        AssertionError: If ``lr`` is missing, or ``betas``, ``amsgrad``,
            ``weight_decay`` or ``eps`` fails validation.
        KeyError: If ``num_epochs`` is missing.
    """
    # Default params
    config.setdefault("eps", 1e-8)
    config.setdefault("amsgrad", False)

    # A JSON config can only express lists, so accept a list for betas and
    # normalise it to the tuple form that is validated below.
    if isinstance(config.get("betas"), list):
        config["betas"] = tuple(config["betas"])

    assert (
        "lr" in config
    ), "Config must contain a learning rate 'lr' section for Adam optimizer"
    betas = config.get("betas")
    assert (
        isinstance(betas, tuple)
        and len(betas) == 2
        # Type check precedes the range check so non-floats fail cleanly.
        and all(isinstance(beta, float) and 0.0 <= beta < 1.0 for beta in betas)
    ), "Config must contain a tuple 'betas' in [0, 1) for Adam optimizer"
    # Validate the remaining hyper-parameters the same way RMSProp does.
    assert isinstance(
        config["amsgrad"], bool
    ), "Config must contain a boolean 'amsgrad' param for Adam optimizer"
    for key in ["weight_decay", "eps"]:
        assert key in config and is_pos_float(
            config[key]
        ), f"Config must contain a positive '{key}' for Adam optimizer"

    lr_config = config["lr"]
    if not isinstance(lr_config, dict):
        lr_config = {"name": "constant", "value": lr_config}
    lr_config["num_epochs"] = config["num_epochs"]
    lr_scheduler = build_param_scheduler(lr_config)

    return cls(
        lr_scheduler=lr_scheduler,
        betas=config["betas"],
        eps=config["eps"],
        weight_decay=config["weight_decay"],
        amsgrad=config["amsgrad"],
    )
def from_config(cls, config: Dict[str, Any]) -> "RMSProp":
    """Instantiates a RMSProp from a configuration.

    Args:
        config: A configuration for a RMSProp. See :func:`__init__` for
            parameters expected in the config. Required keys are ``lr``,
            ``momentum``, ``alpha``, ``weight_decay`` and ``num_epochs``;
            ``eps`` (default ``1e-8``) and ``centered`` (default ``False``)
            are optional. A non-dict ``lr`` is wrapped into a
            ``{"name": "constant", ...}`` scheduler config before being
            handed to ``build_param_scheduler``. The dict is updated in
            place: missing defaults are filled in, and when ``lr`` is a dict
            ``num_epochs`` is written into it.

    Returns:
        A RMSProp instance.

    Raises:
        AssertionError: If ``lr`` is missing, ``momentum``/``alpha`` is not
            a float in [0, 1), ``weight_decay``/``eps`` fails
            ``is_pos_float``, or ``centered`` is not a bool.
        KeyError: If ``num_epochs`` is missing.
    """
    # Default params
    config.setdefault("eps", 1e-8)
    config.setdefault("centered", False)

    assert (
        "lr" in config
    ), "Config must contain a learning rate 'lr' section for RMSProp optimizer"
    for key in ["momentum", "alpha"]:
        # Check the type before comparing: otherwise a non-numeric value
        # (e.g. None or a string) raises TypeError instead of this assertion.
        assert (
            key in config
            and isinstance(config[key], float)
            and 0.0 <= config[key] < 1.0
        ), f"Config must contain a '{key}' in [0, 1) for RMSProp optimizer"
    for key in ["weight_decay", "eps"]:
        assert key in config and is_pos_float(
            config[key]
        ), f"Config must contain a positive '{key}' for RMSProp optimizer"
    assert "centered" in config and isinstance(
        config["centered"], bool
    ), "Config must contain a boolean 'centered' param for RMSProp optimizer"

    lr_config = config["lr"]
    if not isinstance(lr_config, dict):
        lr_config = {"name": "constant", "value": lr_config}
    lr_config["num_epochs"] = config["num_epochs"]
    lr_scheduler = build_param_scheduler(lr_config)

    return cls(
        lr_scheduler=lr_scheduler,
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        alpha=config["alpha"],
        eps=config["eps"],
        centered=config["centered"],
    )
def from_config(cls, config: Dict[str, Any]) -> "SGD":
    """Instantiates a SGD from a configuration.

    Args:
        config: A configuration for a SGD. See :func:`__init__` for
            parameters expected in the config. Required keys are ``lr``,
            ``momentum``, ``weight_decay`` and ``num_epochs``; ``nesterov``
            (default ``False``) is optional. A non-dict ``lr`` is wrapped
            into a ``{"name": "constant", ...}`` scheduler config before
            being handed to ``build_param_scheduler``. The dict is updated in
            place: the ``nesterov`` default is filled in, and when ``lr`` is
            a dict ``num_epochs`` is written into it.

    Returns:
        A SGD instance.

    Raises:
        AssertionError: If ``lr`` is missing, ``momentum`` is not a float in
            [0, 1), ``nesterov`` is not a bool, or ``weight_decay`` fails
            ``is_pos_float``.
        KeyError: If ``num_epochs`` is missing.
    """
    # Default params
    config.setdefault("nesterov", False)

    assert (
        "lr" in config
    ), "Config must contain a learning rate 'lr' section for SGD optimizer"
    # Check the type before comparing: otherwise a non-numeric value
    # (e.g. None or a string) raises TypeError instead of this assertion.
    assert (
        "momentum" in config
        and isinstance(config["momentum"], float)
        and 0.0 <= config["momentum"] < 1.0
    ), "Config must contain a 'momentum' in [0, 1) for SGD optimizer"
    assert "nesterov" in config and isinstance(
        config["nesterov"], bool
    ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
    assert "weight_decay" in config and is_pos_float(
        config["weight_decay"]
    ), "Config must contain a positive 'weight_decay' for SGD optimizer"

    lr_config = config["lr"]
    if not isinstance(lr_config, dict):
        lr_config = {"name": "constant", "value": lr_config}
    lr_config["num_epochs"] = config["num_epochs"]
    lr_scheduler = build_param_scheduler(lr_config)

    return cls(
        lr_scheduler=lr_scheduler,
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=config["nesterov"],
    )
def from_config(cls, config: Dict[str, Any]) -> "RMSPropTF":
    """Instantiates a RMSPropTF from a configuration.

    Args:
        config: A configuration for a RMSPropTF. See :func:`__init__` for
            parameters expected in the config. Every key is optional; the
            defaults are ``lr=0.1``, ``momentum=0.0``, ``weight_decay=0.0``,
            ``alpha=0.99``, ``eps=1e-8`` and ``centered=False``. Missing
            defaults are written into the dict in place. ``lr`` is passed to
            the constructor as-is (no scheduler is built here).

    Returns:
        A RMSPropTF instance.

    Raises:
        AssertionError: If ``momentum``/``alpha`` is not a float in [0, 1),
            ``weight_decay``/``eps`` fails ``is_pos_float``, or ``centered``
            is not a bool.
    """
    # Default params
    config.setdefault("lr", 0.1)
    config.setdefault("momentum", 0.0)
    config.setdefault("weight_decay", 0.0)
    config.setdefault("alpha", 0.99)
    config.setdefault("eps", 1e-8)
    config.setdefault("centered", False)

    for key in ["momentum", "alpha"]:
        # Check the type before comparing: otherwise a non-numeric value
        # (e.g. None or a string) raises TypeError instead of this assertion.
        assert (
            isinstance(config[key], float) and 0.0 <= config[key] < 1.0
        ), f"Config must contain a '{key}' in [0, 1) for RMSPropTF optimizer"
    # weight_decay is validated the same way RMSProp validates it.
    for key in ["weight_decay", "eps"]:
        assert is_pos_float(
            config[key]
        ), f"Config must contain a positive '{key}' for RMSPropTF optimizer"
    assert isinstance(
        config["centered"], bool
    ), "Config must contain a boolean 'centered' param for RMSPropTF optimizer"

    return cls(
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        alpha=config["alpha"],
        eps=config["eps"],
        centered=config["centered"],
    )