コード例 #1
0
def test_flatten_dict():
    """Validate flatten_dict can handle nested dictionaries and argparse Namespace."""

    # Test basic dict flattening with custom delimiter
    params = {"a": {"b": "c"}}
    params = _flatten_dict(params, "--")

    assert "a" not in params
    assert params["a--b"] == "c"

    # Test complex nested dict flattening: non-string keys, list leaves and
    # multiple nesting levels, using the default "/" delimiter.
    params = {
        "a": {
            5: {
                "foo": "bar"
            }
        },
        "b": 6,
        "c": {
            7: [1, 2, 3, 4],
            8: "foo",
            9: {
                10: "bar"
            }
        }
    }
    params = _flatten_dict(params)

    assert "a" not in params
    assert params["a/5/foo"] == "bar"
    assert params["b"] == 6
    assert params["c/7"] == [1, 2, 3, 4]
    assert params["c/8"] == "foo"
    assert params["c/9/10"] == "bar"

    # Test flattening of argparse Namespace
    opt = "--max_epochs 1".split(" ")
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parent_parser=parser)
    params = parser.parse_args(opt)
    wrapping_dict = {"params": params}
    params = _flatten_dict(wrapping_dict)

    # isinstance is the idiomatic type check (was: type(params) == dict)
    assert isinstance(params, dict)
    assert params["params/logger"] is True
    assert params["params/gpus"] == "None"
    assert "logger" not in params
    assert "gpus" not in params
コード例 #2
0
 def log_hyperparams(self, params: Union[Dict[str, Any],
                                         Namespace]) -> None:
     """Send the given hyperparameters to the underlying experiment.

     Args:
         params: a dict or argparse Namespace of hyperparameters
     """
     # TODO: HACK figure out where this is being set to true
     self.experiment.debug = self.debug
     # normalize to a flat dict before handing off to the experiment
     flat_params = _flatten_dict(_convert_params(params))
     self.experiment.argparse(Namespace(**flat_params))
コード例 #3
0
def test_sanitize_callable_params():
    """Callback function are not serializiable.

    Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
    """
    parser = Trainer.add_argparse_args(parent_parser=ArgumentParser())
    params = parser.parse_args(["--max_epochs", "1"])

    # NOTE: these local function names matter — the sanitizer falls back to
    # the callable's __name__, which the asserts below compare against.
    def return_something():
        return "something"

    def wrapper_something():
        return return_something

    params.something = return_something
    params.wrapper_something = wrapper_something
    params.wrapper_something_wo_name = lambda: lambda: "1"

    params = _sanitize_callable_params(_flatten_dict(_convert_params(params)))

    assert params["gpus"] == "None"
    assert params["something"] == "something"
    assert params["wrapper_something"] == "wrapper_something"
    assert params["wrapper_something_wo_name"] == "<lambda>"
コード例 #4
0
    def log_hyperparams(self, params: Union[Dict[str, Any],
                                            Namespace]) -> None:
        """Log hyperparameters to MLflow, one param at a time.

        Args:
            params: a dict or argparse Namespace of hyperparameters
        """
        params = _flatten_dict(_convert_params(params))
        for k, v in params.items():
            # MLflow caps parameter values at 250 characters; log the ones
            # that fit and warn about (rather than fail on) the rest.
            if len(str(v)) <= 250:
                self.experiment.log_param(self.run_id, k, v)
                continue
            rank_zero_warn(
                f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}",
                category=RuntimeWarning)
コード例 #5
0
    def log_hyperparams(self,
                        params: Union[Dict[str, Any], Namespace],
                        metrics: Optional[Dict[str, Any]] = None) -> None:
        """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
        hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs
        to display the new ones with hyperparameters.

        Args:
            params: a dictionary-like container with the hyperparameters
            metrics: Dictionary with metric names as keys and measured quantities as values
        """

        params = _convert_params(params)

        # keep the un-flattened params on the logger itself (merged when the
        # container is an OmegaConf object, dict-updated otherwise)
        if _OMEGACONF_AVAILABLE and isinstance(params, Container):
            self.hparams = OmegaConf.merge(self.hparams, params)
        else:
            self.hparams.update(params)

        # flatten and sanitize only AFTER the merge above, so stored hparams
        # keep their nesting; tensorboard needs the flat, sanitized form
        params = self._sanitize_params(_flatten_dict(params))

        if metrics is None:
            if self._default_hp_metric:
                metrics = {"hp_metric": -1}
        elif not isinstance(metrics, dict):
            metrics = {"hp_metric": metrics}

        if metrics:
            self.log_metrics(metrics, 0)
            exp, ssi, sei = hparams(params, metrics)
            writer = self.experiment._get_file_writer()
            for summary in (exp, ssi, sei):
                writer.add_summary(summary)
コード例 #6
0
 def log_hyperparams(self, params: Union[Dict[str, Any],
                                         Namespace]) -> None:
     """Push hyperparameters into the wandb run config.

     Args:
         params: a dict or argparse Namespace of hyperparameters
     """
     # callables cannot be serialized by wandb, so sanitize them away;
     # allow_val_change lets repeated calls overwrite existing config keys
     sanitized = _sanitize_callable_params(_flatten_dict(_convert_params(params)))
     self.experiment.config.update(sanitized, allow_val_change=True)
コード例 #7
0
ファイル: comet.py プロジェクト: ashleve/pytorch-lightning
 def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
     """Log hyperparameters to the Comet experiment.

     Args:
         params: a dict or argparse Namespace of hyperparameters
     """
     converted = _convert_params(params)
     self.experiment.log_parameters(_flatten_dict(converted))