예제 #1
0
    def get_experiment_name(self,
                            setting: SettingABC,
                            experiment_id: str = None) -> str:
        """Gets a unique name for the experiment where `self` is applied to `setting`.

        This experiment name will be passed to `orion` when performing a run of
        Hyper-Parameter Optimization.

        Parameters
        ----------
        - setting : Setting

            The `Setting` onto which this method will be applied. Its `dataset`
            attribute must be a `str`, since it is embedded in the returned name.

        - experiment_id: str, optional

            A custom hash to append to the experiment name. When `None` (default)
            or empty, a unique hash will be created based on the values of the
            Setting's fields (the flattened output of `setting.to_dict()`).

        Returns
        -------
        str
            The name for the experiment, formatted as
            `{self.get_name()}-{setting.get_name()}_{setting.dataset}_{experiment_id}`.

        Raises
        ------
        AssertionError
            If `setting.dataset` is not a `str`.
        """
        if not experiment_id:
            flat_fields = flatten_dict(setting.to_dict())
            # The fields are handed to `compute_identity` as keyword arguments, and
            # `**` unpacking raises a TypeError on any non-string key. Stringify the
            # keys first; keys that are already strings are passed through as-is.
            keyword_fields = {
                str(key): value for key, value in flat_fields.items()
            }
            experiment_id = compute_identity(size=5, **keyword_fields)
        assert isinstance(setting.dataset,
                          str), "assuming that dataset is a str for now."
        return (
            f"{self.get_name()}-{setting.get_name()}_{setting.dataset}_{experiment_id}"
        )
예제 #2
0
def cleanup(
        message: Dict[str, Union[Dict, str, float, Any]],
        sep: str = "/",
        keys_to_remove: List[str] = None) -> Dict[str, Union[float, Tensor]]:
    """Flatten a message dict and simplify its keys before it is logged to wandb.

    The message is first flattened with `flatten_dict`, using `sep` as the
    separator. Then, for every flattened key:

    - the entry is dropped when the key contains any of the `keys_to_remove`
      substrings;
    - otherwise each `{sep}losses{sep}` and `{sep}metrics{sep}` segment is
      collapsed into a single `sep`, the key parts are passed through
      `unique_consecutive` (to drop repeats such as "/270/270"), and the value
      is stored back under the simplified key.

    Args:
        message (Dict[str, Union[Dict, str, float, Any]]): The dict to clean up.
        sep (str, optional): Separator between the levels of a flattened key.
            Defaults to "/".
        keys_to_remove (List[str], optional): Substrings; any flattened key that
            contains one of them is discarded. Defaults to None (nothing is
            removed).

    Returns:
        Dict[str, Union[float, Tensor]]: Cleaned up dict.
    """
    from sequoia.utils.utils import flatten_dict

    flat = flatten_dict(message, separator=sep)
    unwanted_flags = keys_to_remove or []
    # Segments that only add noise to the key, e.g. '/losses/' and '/metrics/'.
    noise_segments: List[str] = [f"{sep}losses{sep}", f"{sep}metrics{sep}"]

    # Walk over a snapshot of the keys, since entries are popped and re-inserted
    # into the same dict as we go.
    for original_key in list(flat.keys()):
        value = flat.pop(original_key)
        if any(flag in original_key for flag in unwanted_flags):
            # Discarded entry: it was already popped above.
            continue

        # Example input:
        # "Task_losses/Task1/losses/Test/losses/rotate/losses/270/metrics/270/accuracy"
        simplified = original_key
        for segment in noise_segments:
            # Repeat until stable: one pass of `replace` can leave a new match
            # behind when segments overlap (e.g. "/losses/losses/").
            while segment in simplified:
                simplified = simplified.replace(segment, sep)
        # --> "Task_losses/Task1/Test/rotate/270/270/accuracy"

        # Drop whitespace-only parts (empty parts are kept, as `"".isspace()` is
        # False), then merge repeated consecutive parts (ex: "/270/270" above).
        kept_parts = [
            part for part in simplified.split(sep) if not part.isspace()
        ]
        simplified = sep.join(unique_consecutive(kept_parts))
        # --> "Task_losses/Task1/Test/rotate/270/accuracy"
        flat[simplified] = value
    return flat
예제 #3
0
    def shared_modules(self) -> Dict[str, nn.Module]:
        """Collects the trainable modules of `self` that are shared across tasks.

        Exposing these modules lets regularization-based auxiliary tasks (EWC, for
        example) make use of their weights.

        The result starts from the parent class's `shared_modules()` and is then
        extended with the shared modules of every task in `self.tasks`. Each task's
        (possibly nested) dictionary is flattened, and its entries are stored under
        the name `"{task_name}.{module_name}"`.

        Returns
        -------
        Dict[str, nn.Module]:
            Dictionary mapping from name to the shared modules, if any.
        """
        # TODO: What separator to use when dealing with nested dictionaries? I seem
        # to recall that ModuleDicts don't like some separators.
        sep = "."
        modules = super().shared_modules()
        for task_name, task in self.tasks.items():
            flat_task_modules = flatten_dict(task.shared_modules(), separator=sep)
            for module_name, module in flat_task_modules.items():
                # Prefix with the task name to tell the tasks' modules apart.
                modules[f"{task_name}{sep}{module_name}"] = module
        return modules