def test_scheduler_with_mixed_types(self):
        """Check CompositeParamScheduler under each interval-scaling mode.

        Covers "rescaled" for both segments, "fixed" for both, the default
        (expected to equal the all-rescaled schedule) and "rescaled" followed
        by "fixed". Each composite schedule is compared with the two component
        schedulers sampled directly.
        """
        config = self._get_valid_mixed_config()
        first = build_param_scheduler(config["schedulers"][0])
        second = build_param_scheduler(config["schedulers"][1])
        num_epochs = self._num_epochs
        half = int(num_epochs / 2)
        # The expected values below assume two equal-length segments: a
        # rescaled segment is sampled at every other epoch, while a fixed
        # segment keeps its own timeline over its half of training.
        every_other = range(0, num_epochs, 2)
        first_half = range(0, half)
        second_half = range(half, num_epochs)

        def sample(param_fn, epochs):
            # Evaluate at each epoch's fraction of training, rounded to 4
            # decimals so tiny float differences do not fail the comparison.
            return [round(param_fn(epoch / num_epochs), 4) for epoch in epochs]

        def composite():
            # Rebuild from the current (mutated) config and sample every epoch.
            return sample(
                CompositeParamScheduler.from_config(config), range(num_epochs)
            )

        # Check scaled
        config["interval_scaling"] = ["rescaled", "rescaled"]
        scaled_schedule = composite()
        self.assertEqual(
            scaled_schedule, sample(first, every_other) + sample(second, every_other)
        )

        # Check fixed
        config["interval_scaling"] = ["fixed", "fixed"]
        self.assertEqual(
            composite(), sample(first, first_half) + sample(second, second_half)
        )

        # Check that default is rescaled
        del config["interval_scaling"]
        self.assertEqual(scaled_schedule, composite())

        # Check warmup of rescaled then fixed
        config["interval_scaling"] = ["rescaled", "fixed"]
        self.assertEqual(
            composite(), sample(first, every_other) + sample(second, second_half)
        )
    def test_linear_scheduler_no_gaps(self):
        """Composed linear schedulers must leave no gaps in the schedule.

        First the composite built from the linear config must yield the ramp
        0.0, 0.1, ..., 0.9. Then a fixed composition of the first linear
        scheduler with itself must match that scheduler on its own.
        """
        config = self._get_valid_linear_config()
        num_epochs = self._num_epochs
        positions = [epoch / num_epochs for epoch in range(num_epochs)]

        # Check rescaled
        composite = CompositeParamScheduler.from_config(config)
        actual = [composite(where) for where in positions]
        self.assertEqual(
            [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], actual
        )

        # Check fixed composition gives same result as only 1 scheduler
        config["schedulers"][1] = config["schedulers"][0]
        config["interval_scaling"] = ["fixed", "fixed"]
        composite = CompositeParamScheduler.from_config(config)
        single = build_param_scheduler(config["schedulers"][0])
        actual = [composite(where) for where in positions]
        self.assertEqual([single(where) for where in positions], actual)
    def test_build_composite_scheduler(self):
        """Composite schedulers can be built from config and constructed directly."""
        config = self._get_valid_mixed_config()
        self.assertIsInstance(build_param_scheduler(config), CompositeParamScheduler)

        # Direct construction with one rescaled and one fixed segment.
        components = list(map(build_param_scheduler, config["schedulers"]))
        composite = CompositeParamScheduler(
            schedulers=components,
            lengths=config["lengths"],
            update_interval=UpdateInterval.EPOCH,
            interval_scaling=[IntervalScaling.RESCALED, IntervalScaling.FIXED],
        )
        self.assertIsInstance(composite, CompositeParamScheduler)
    def test_scheduler_update_interval(self):
        """``update_interval`` defaults to STEP and is parsed from the config."""
        config = self._get_valid_mixed_config()

        # Check default
        self.assertEqual(
            CompositeParamScheduler.from_config(config).update_interval,
            UpdateInterval.STEP,
        )

        # Check explicit "step" and "epoch" values. Each case edits its own
        # deep copy, so this loop never modifies `config`.
        cases = (("step", UpdateInterval.STEP), ("epoch", UpdateInterval.EPOCH))
        for interval_name, expected in cases:
            variant = copy.deepcopy(config)
            variant["update_interval"] = interval_name
            self.assertEqual(build_param_scheduler(variant).update_interval, expected)
Example #5
0
    def from_config(cls, config: Dict[str, Any]) -> "Adam":
        """Instantiates a Adam from a configuration.

        Fills in defaults for ``eps`` (1e-8) and ``amsgrad`` (False) and
        converts a list-valued ``betas`` to a tuple. Both changes are written
        back into ``config``.

        Args:
            config: A configuration for a Adam.
                See :func:`__init__` for parameters expected in the config.
                Must also contain ``num_epochs``, which is forwarded to the
                learning-rate scheduler config.

        Returns:
            A Adam instance.

        Raises:
            AssertionError: if ``lr`` is missing, ``betas`` is not a pair of
                floats in [0, 1), or ``weight_decay`` is missing or fails
                ``is_pos_float``.
            KeyError: if ``num_epochs`` is missing.
        """
        # Default params
        config.setdefault("eps", 1e-8)
        config.setdefault("amsgrad", False)

        # A JSON config can only express lists, so convert a list-valued
        # betas into the tuple form validated below.
        if isinstance(config.get("betas"), list):
            config["betas"] = tuple(config["betas"])

        assert (
            "lr" in config
        ), "Config must contain a learning rate 'lr' section for Adam optimizer"
        betas = config.get("betas")
        # Type checks run before the range comparison, so a non-float entry
        # fails this assertion rather than raising TypeError.
        assert (
            isinstance(betas, tuple)
            and len(betas) == 2
            and all(isinstance(beta, float) and 0.0 <= beta < 1.0 for beta in betas)
        ), "Config must contain a tuple 'betas' in [0, 1) for Adam optimizer"
        assert "weight_decay" in config and is_pos_float(
            config["weight_decay"]
        ), "Config must contain a positive 'weight_decay' for Adam optimizer"

        lr_config = config["lr"]
        if isinstance(lr_config, dict):
            # Shallow copy, so adding "num_epochs" below does not modify the
            # caller's own 'lr' section.
            lr_config = dict(lr_config)
        else:
            # Any non-dict value is wrapped as a constant learning rate.
            lr_config = {"name": "constant", "value": lr_config}
        lr_config["num_epochs"] = config["num_epochs"]
        lr_scheduler = build_param_scheduler(lr_config)

        return cls(
            lr_scheduler=lr_scheduler,
            betas=config["betas"],
            eps=config["eps"],
            weight_decay=config["weight_decay"],
            amsgrad=config["amsgrad"],
        )
Example #6
0
    def from_config(cls, config: Dict[str, Any]) -> "RMSProp":
        """Instantiates a RMSProp from a configuration.

        Fills in defaults for ``eps`` (1e-8) and ``centered`` (False). The
        defaults are written back into ``config``.

        Args:
            config: A configuration for a RMSProp.
                See :func:`__init__` for parameters expected in the config.
                Must also contain ``num_epochs``, which is forwarded to the
                learning-rate scheduler config.

        Returns:
            A RMSProp instance.

        Raises:
            AssertionError: if ``lr`` is missing, ``momentum`` or ``alpha`` is
                not a float in [0, 1), ``weight_decay`` or ``eps`` fails
                ``is_pos_float``, or ``centered`` is not a bool.
            KeyError: if ``num_epochs`` is missing.
        """
        # Default params
        config.setdefault("eps", 1e-8)
        config.setdefault("centered", False)

        assert (
            "lr" in config
        ), "Config must contain a learning rate 'lr' section for RMSProp optimizer"
        for key in ["momentum", "alpha"]:
            # Check the type before comparing: a non-numeric value (e.g. a
            # string or None) would otherwise raise TypeError from the
            # comparison instead of producing this assertion's message.
            assert (
                key in config
                and isinstance(config[key], float)
                and 0.0 <= config[key] < 1.0
            ), f"Config must contain a '{key}' in [0, 1) for RMSProp optimizer"
        for key in ["weight_decay", "eps"]:
            assert key in config and is_pos_float(
                config[key]
            ), f"Config must contain a positive '{key}' for RMSProp optimizer"
        assert "centered" in config and isinstance(
            config["centered"], bool
        ), "Config must contain a boolean 'centered' param for RMSProp optimizer"

        lr_config = config["lr"]
        if isinstance(lr_config, dict):
            # Shallow copy, so adding "num_epochs" below does not modify the
            # caller's own 'lr' section.
            lr_config = dict(lr_config)
        else:
            # Any non-dict value is wrapped as a constant learning rate.
            lr_config = {"name": "constant", "value": lr_config}
        lr_config["num_epochs"] = config["num_epochs"]
        lr_scheduler = build_param_scheduler(lr_config)

        return cls(
            lr_scheduler=lr_scheduler,
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            alpha=config["alpha"],
            eps=config["eps"],
            centered=config["centered"],
        )
Example #7
0
    def from_config(cls, config: Dict[str, Any]) -> "SGD":
        """Instantiates a SGD from a configuration.

        Fills in a default of False for ``nesterov``. The default is written
        back into ``config``.

        Args:
            config: A configuration for a SGD.
                See :func:`__init__` for parameters expected in the config.
                Must also contain ``num_epochs``, which is forwarded to the
                learning-rate scheduler config.

        Returns:
            A SGD instance.

        Raises:
            AssertionError: if ``lr`` is missing, ``momentum`` is not a float
                in [0, 1), ``nesterov`` is not a bool, or ``weight_decay`` is
                missing or fails ``is_pos_float``.
            KeyError: if ``num_epochs`` is missing.
        """
        # Default params
        config.setdefault("nesterov", False)

        assert (
            "lr" in config
        ), "Config must contain a learning rate 'lr' section for SGD optimizer"
        # Check the type before comparing: a non-numeric momentum (e.g. a
        # string or None) would otherwise raise TypeError from the comparison
        # instead of producing this assertion's message.
        assert (
            "momentum" in config
            and isinstance(config["momentum"], float)
            and 0.0 <= config["momentum"] < 1.0
        ), "Config must contain a 'momentum' in [0, 1) for SGD optimizer"
        assert "nesterov" in config and isinstance(
            config["nesterov"], bool
        ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
        assert "weight_decay" in config and is_pos_float(
            config["weight_decay"]
        ), "Config must contain a positive 'weight_decay' for SGD optimizer"

        lr_config = config["lr"]
        if isinstance(lr_config, dict):
            # Shallow copy, so adding "num_epochs" below does not modify the
            # caller's own 'lr' section.
            lr_config = dict(lr_config)
        else:
            # Any non-dict value is wrapped as a constant learning rate.
            lr_config = {"name": "constant", "value": lr_config}
        lr_config["num_epochs"] = config["num_epochs"]
        lr_scheduler = build_param_scheduler(lr_config)

        return cls(
            lr_scheduler=lr_scheduler,
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            nesterov=config["nesterov"],
        )
 def test_build_polynomial_scheduler(self):
     """The valid config is built into a PolynomialDecayParamScheduler."""
     built = build_param_scheduler(self._get_valid_config())
     self.assertIsInstance(built, PolynomialDecayParamScheduler)
 def test_build_non_equi_step_scheduler(self):
     """The valid config is built into a MultiStepParamScheduler."""
     built = build_param_scheduler(self._get_valid_config())
     self.assertIsInstance(built, MultiStepParamScheduler)
Example #10
0
 def test_build_linear_scheduler(self):
     """The valid config is built into a LinearParamScheduler."""
     built = build_param_scheduler(self._get_valid_config())
     self.assertIsInstance(built, LinearParamScheduler)
Example #11
0
 def test_build_composite_scheduler(self):
     """The valid mixed config is built into a CompositeParamScheduler."""
     built = build_param_scheduler(self._get_valid_mixed_config())
     self.assertIsInstance(built, CompositeParamScheduler)
 def test_build_cosine_scheduler(self):
     """The valid decay config is built into a CosineParamScheduler."""
     built = build_param_scheduler(self._get_valid_decay_config())
     self.assertIsInstance(built, CosineParamScheduler)
 def test_build_step_with_fixed_gamma_scheduler(self):
     """The valid config is built into a StepWithFixedGammaParamScheduler."""
     built = build_param_scheduler(self._get_valid_config())
     self.assertIsInstance(built, StepWithFixedGammaParamScheduler)