Exemplo n.º 1
0
    def from_config(cls, config: Dict[str, Any]) -> "Adam":
        """Instantiates an Adam optimizer from a configuration.

        Missing optional keys are filled in on ``config`` itself: ``eps``
        defaults to ``1e-8`` and ``amsgrad`` to ``False``. A list-valued
        ``betas`` (the only sequence type JSON can express) is converted to a
        tuple in place.

        Args:
            config: A configuration for an Adam optimizer. It must contain
                ``lr`` (either a param-scheduler config dict, or a plain value
                that is wrapped into a constant scheduler), ``betas`` (a pair
                of floats in [0, 1)), ``weight_decay`` (accepted by
                ``is_pos_float``) and ``num_epochs``.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            An Adam instance.

        Raises:
            AssertionError: If ``lr``, ``betas`` or ``weight_decay`` is
                missing or invalid.
            KeyError: If ``num_epochs`` is missing.
        """
        # Default params
        config.setdefault("eps", 1e-8)
        config.setdefault("amsgrad", False)

        # A JSON config can only hold lists, so accept a list for betas and
        # normalise it to the tuple validated below.
        if isinstance(config.get("betas"), list):
            config["betas"] = tuple(config["betas"])

        assert (
            "lr" in config
        ), "Config must contain a learning rate 'lr' section for Adam optimizer"
        betas = config.get("betas")
        # Type checks run before the range checks, so a malformed value fails
        # this assertion rather than raising TypeError on comparison.
        assert (
            isinstance(betas, tuple)
            and len(betas) == 2
            and all(
                isinstance(beta, float) and 0.0 <= beta < 1.0 for beta in betas
            )
        ), "Config must contain a tuple 'betas' in [0, 1) for Adam optimizer"
        assert "weight_decay" in config and is_pos_float(
            config["weight_decay"]
        ), "Config must contain a positive 'weight_decay' for Adam optimizer"
        # NOTE(review): 'eps' and 'amsgrad' are defaulted above but not
        # validated here.

        lr_config = config["lr"]
        if not isinstance(lr_config, dict):
            # A bare value is shorthand for a constant schedule.
            lr_config = {"name": "constant", "value": lr_config}

        # When config["lr"] is already a dict, this writes num_epochs into the
        # caller's dict; a missing num_epochs raises KeyError.
        lr_config["num_epochs"] = config["num_epochs"]
        lr_scheduler = build_param_scheduler(lr_config)

        return cls(
            lr_scheduler=lr_scheduler,
            betas=config["betas"],
            eps=config["eps"],
            weight_decay=config["weight_decay"],
            amsgrad=config["amsgrad"],
        )
Exemplo n.º 2
0
    def from_config(cls, config: Dict[str, Any]) -> "RMSProp":
        """Instantiates an RMSProp optimizer from a configuration.

        ``eps`` (default ``1e-8``) and ``centered`` (default ``False``) are
        filled in on ``config`` itself when missing.

        Args:
            config: A configuration for an RMSProp optimizer. It must contain
                ``lr`` (either a param-scheduler config dict, or a plain value
                that is wrapped into a constant scheduler), ``momentum`` and
                ``alpha`` (floats in [0, 1)), ``weight_decay`` (accepted by
                ``is_pos_float``) and ``num_epochs``.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            An RMSProp instance.

        Raises:
            AssertionError: If a required key is missing or has an invalid
                value.
            KeyError: If ``num_epochs`` is missing.
        """
        # Default params
        config.setdefault("eps", 1e-8)
        config.setdefault("centered", False)

        assert (
            "lr" in config
        ), "Config must contain a learning rate 'lr' section for RMSProp optimizer"
        for key in ("momentum", "alpha"):
            value = config.get(key)
            # Check the type before the range: comparing e.g. a string or None
            # with 0.0 would raise TypeError instead of this assertion.
            assert isinstance(value, float) and 0.0 <= value < 1.0, (
                f"Config must contain a '{key}' in [0, 1) for RMSProp optimizer"
            )
        for key in ("weight_decay", "eps"):
            assert key in config and is_pos_float(
                config[key]
            ), f"Config must contain a positive '{key}' for RMSProp optimizer"
        # 'centered' is always present after setdefault above.
        assert isinstance(
            config["centered"], bool
        ), "Config must contain a boolean 'centered' param for RMSProp optimizer"

        lr_config = config["lr"]
        if not isinstance(lr_config, dict):
            # A bare value is shorthand for a constant schedule.
            lr_config = {"name": "constant", "value": lr_config}

        # When config["lr"] is already a dict, this writes num_epochs into the
        # caller's dict; a missing num_epochs raises KeyError.
        lr_config["num_epochs"] = config["num_epochs"]
        lr_scheduler = build_param_scheduler(lr_config)

        return cls(
            lr_scheduler=lr_scheduler,
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            alpha=config["alpha"],
            eps=config["eps"],
            centered=config["centered"],
        )
Exemplo n.º 3
0
    def from_config(cls, config: Dict[str, Any]) -> "SGD":
        """Instantiates an SGD optimizer from a configuration.

        ``nesterov`` defaults to ``False`` and is filled in on ``config``
        itself when missing.

        Args:
            config: A configuration for an SGD optimizer. It must contain
                ``lr`` (either a param-scheduler config dict, or a plain value
                that is wrapped into a constant scheduler), ``momentum`` (a
                float in [0, 1)), ``weight_decay`` (accepted by
                ``is_pos_float``) and ``num_epochs``.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            An SGD instance.

        Raises:
            AssertionError: If a required key is missing or has an invalid
                value.
            KeyError: If ``num_epochs`` is missing.
        """
        # Default params
        config.setdefault("nesterov", False)

        assert (
            "lr" in config
        ), "Config must contain a learning rate 'lr' section for SGD optimizer"
        momentum = config.get("momentum")
        # Check the type before the range: comparing e.g. a string or None
        # with 0.0 would raise TypeError instead of this assertion.
        assert isinstance(momentum, float) and 0.0 <= momentum < 1.0, (
            "Config must contain a 'momentum' in [0, 1) for SGD optimizer"
        )
        # 'nesterov' is always present after setdefault above.
        assert isinstance(
            config["nesterov"], bool
        ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
        assert "weight_decay" in config and is_pos_float(
            config["weight_decay"]
        ), "Config must contain a positive 'weight_decay' for SGD optimizer"

        lr_config = config["lr"]
        if not isinstance(lr_config, dict):
            # A bare value is shorthand for a constant schedule.
            lr_config = {"name": "constant", "value": lr_config}

        # When config["lr"] is already a dict, this writes num_epochs into the
        # caller's dict; a missing num_epochs raises KeyError.
        lr_config["num_epochs"] = config["num_epochs"]
        lr_scheduler = build_param_scheduler(lr_config)

        return cls(
            lr_scheduler=lr_scheduler,
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            nesterov=config["nesterov"],
        )
Exemplo n.º 4
0
    def from_config(cls, config: Dict[str, Any]) -> "RMSPropTF":
        """Instantiates an RMSPropTF optimizer from a configuration.

        Every key is optional. Missing keys are filled in on ``config``
        itself with ``lr=0.1``, ``momentum=0.0``, ``weight_decay=0.0``,
        ``alpha=0.99``, ``eps=1e-8`` and ``centered=False``.

        Args:
            config: A configuration for an RMSPropTF optimizer.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            An RMSPropTF instance.

        Raises:
            AssertionError: If ``momentum`` or ``alpha`` is not a float in
                [0, 1), ``eps`` is rejected by ``is_pos_float``, or
                ``centered`` is not a bool.
        """
        # Default params
        config.setdefault("lr", 0.1)
        config.setdefault("momentum", 0.0)
        config.setdefault("weight_decay", 0.0)
        config.setdefault("alpha", 0.99)
        config.setdefault("eps", 1e-8)
        config.setdefault("centered", False)

        for key in ("momentum", "alpha"):
            value = config[key]
            # Check the type before the range: comparing e.g. a string or None
            # with 0.0 would raise TypeError instead of this assertion.
            assert isinstance(value, float) and 0.0 <= value < 1.0, (
                f"Config must contain a '{key}' in [0, 1) for RMSPropTF optimizer"
            )
        assert is_pos_float(
            config["eps"]
        ), "Config must contain a positive 'eps' for RMSPropTF optimizer"
        assert isinstance(
            config["centered"], bool
        ), "Config must contain a boolean 'centered' param for RMSPropTF optimizer"
        # NOTE(review): 'lr' and 'weight_decay' are passed through unvalidated.

        return cls(
            lr=config["lr"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            alpha=config["alpha"],
            eps=config["eps"],
            centered=config["centered"],
        )