import collections
import re
from typing import Dict, List, Union

import torch
from torch.nn import Module as Model  # stand-in for the library's ``Model`` type alias
from torch.nn.parallel import DataParallel, DistributedDataParallel


def grad_norm(*, model: Model, prefix: str, norm_type: int) -> Dict:
        """Computes gradient norms for a given model.

        Args:
            model (Model): model which gradients to be saved.
            prefix (str): prefix for keys in resulting dictionary.
            norm_type (int): norm type of gradient norm.

        Returns:
            Dict: dictionary in which gradient norms are stored.
        """
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            model = model.module

        total_norm = 0.0
        grad_norm = {}

        for tag, value in model.named_parameters():
            tag = tag.replace(".", "/")
            metrics_tag = f"{prefix}/{tag}"
            param_norm = value.grad.data.norm(norm_type).item()
            total_norm += param_norm ** norm_type
            grad_norm[metrics_tag] = param_norm

        total_norm = total_norm ** (1.0 / norm_type)
        metrics_tag = f"{prefix}/total"
        grad_norm[metrics_tag] = total_norm

        return grad_norm
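

# Usage sketch (illustrative, not part of the original snippets): run a
# backward pass on a placeholder model, then inspect the per-parameter and
# total gradient norms.
def _example_grad_norm():
    net = torch.nn.Sequential(
        torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1)
    )
    inputs, targets = torch.randn(16, 8), torch.randn(16, 1)

    loss = torch.nn.functional.mse_loss(net(inputs), targets)
    loss.backward()  # gradients must be populated before computing their norms

    norms = grad_norm(model=net, prefix="grad_norm", norm_type=2)
    print(norms["grad_norm/total"])  # L2 norm accumulated over all parameters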


def set_requires_grad(model: Model, requires_grad: Union[bool, Dict[str, bool]]):
    """Sets the ``requires_grad`` value for all model parameters.

    Example::

        >>> model = SimpleModel()
        >>> set_requires_grad(model, requires_grad=True)
        >>> # or
        >>> model = SimpleModel()
        >>> set_requires_grad(
        >>>     model, requires_grad={n: False for n, _ in model.named_parameters()}
        >>> )

    Args:
        model (torch.nn.Module): model
        requires_grad (Union[bool, Dict[str, bool]]): a single flag applied to
            every parameter, or a dict with one flag per named parameter
    """
    if isinstance(requires_grad, dict):
        for name, param in model.named_parameters():
            assert (
                name in requires_grad
            ), f"Parameter `{name}` is missing from the `requires_grad` dict"
            param.requires_grad = requires_grad[name]
    else:
        requires_grad = bool(requires_grad)
        for param in model.parameters():
            param.requires_grad = requires_grad
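

# Usage sketch (illustrative, not part of the original snippets): freeze every
# parameter with the bool form, then re-enable a single layer with the dict
# form, which expects an entry for every named parameter. The model is a
# placeholder.
def _example_set_requires_grad():
    net = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 1))

    # bool form: one flag for the whole model
    set_requires_grad(net, requires_grad=False)

    # dict form: unfreeze only the last layer ("1.weight", "1.bias")
    flags = {name: name.startswith("1.") for name, _ in net.named_parameters()}
    set_requires_grad(net, requires_grad=flags)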
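

# ``merge_dicts`` is used by ``process_model_params`` below. A minimal
# stand-in is provided here (assumption): a shallow left-to-right merge,
# which is sufficient for the flat per-layer option dicts used below.
def merge_dicts(*dicts: dict) -> dict:
    """Shallow left-to-right merge of several dicts."""
    merged: dict = {}
    for d in dicts:
        merged.update(d)
    return merged

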
def process_model_params(
    model: Model,
    layerwise_params: Dict[str, dict] = None,
    no_bias_weight_decay: bool = True,
    lr_scaling: float = 1.0,
) -> List[Union[torch.nn.Parameter, dict]]:
    """Gains model parameters for ``torch.optim.Optimizer``.

    Args:
        model (torch.nn.Module): Model to process
        layerwise_params (Dict): Order-sensitive dict where each key is
            a regex pattern and each value is a dict of layer-wise options
            for the parameters matching that pattern
        no_bias_weight_decay (bool): If True, sets ``weight_decay`` to 0.0
            for all ``bias`` parameters in the model
        lr_scaling (float): layer-wise learning rate scaling;
            if 1.0, learning rates are not scaled

    Returns:
        iterable: parameter groups for an optimizer

    Example::

        >>> model = catalyst.contrib.models.segmentation.ResnetUnet()
        >>> layerwise_params = collections.OrderedDict([
        >>>     ("conv1.*", dict(lr=0.001, weight_decay=0.0003)),
        >>>     ("conv.*", dict(lr=0.002))
        >>> ])
        >>> params = process_model_params(model, layerwise_params)
        >>> optimizer = torch.optim.Adam(params, lr=0.0003)

    """
    params = list(model.named_parameters())
    layerwise_params = layerwise_params or collections.OrderedDict()

    model_params = []
    for name, parameters in params:
        options = {}
        for pattern, options_ in layerwise_params.items():
            if re.match(pattern, name) is not None:
                # all new LR rules write on top of the old ones
                options = merge_dicts(options, options_)

        # no bias decay from https://arxiv.org/abs/1812.01187
        if no_bias_weight_decay and name.endswith("bias"):
            options["weight_decay"] = 0.0

        # lr linear scaling from https://arxiv.org/pdf/1706.02677.pdf
        if "lr" in options:
            options["lr"] *= lr_scaling

        model_params.append({"params": parameters, **options})

    return model_params
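

# Usage sketch (illustrative, not part of the original snippets): regex-based
# per-layer options with layer-wise lr scaling; bias parameters get
# ``weight_decay=0.0`` because ``no_bias_weight_decay`` defaults to True.
def _example_process_model_params():
    net = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 1))
    layerwise = collections.OrderedDict([(r"0\..*", {"lr": 1e-3})])

    groups = process_model_params(net, layerwise_params=layerwise, lr_scaling=0.1)
    # layer "0" trains with lr = 1e-3 * 0.1 = 1e-4; other parameters use the
    # optimizer default of 3e-4; every bias gets weight_decay=0.0
    return torch.optim.Adam(groups, lr=3e-4)

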
def get_requires_grad(model: Model):
    """Gets the ``requires_grad`` value for all model parameters.

    Example::

        >>> model = SimpleModel()
        >>> requires_grad = get_requires_grad(model)

    Args:
        model (torch.nn.Module): model

    Returns:
        requires_grad (Dict[str, bool]): per-parameter ``requires_grad`` values
    """
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
    return requires_grad
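

# Usage sketch (illustrative, not part of the original snippets): snapshot the
# per-parameter ``requires_grad`` flags, freeze the model temporarily, then
# restore the original flags.
def _example_requires_grad_roundtrip():
    net = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 1))

    saved = get_requires_grad(net)  # e.g. {"0.weight": True, ...}
    set_requires_grad(net, requires_grad=False)
    # ... run a frozen forward pass / feature extraction here ...
    set_requires_grad(net, requires_grad=saved)  # restore the original flags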