Example #1
    def param_groups(self,
                     lr: Optional[float] = None,
                     lr_layer_scale: float = 1.0,
                     decay_base_params: bool = False):
        r"""Create parameter groups for optimizers. When
        :attr:`lr_layer_decay_rate` is not 1.0, parameters from each layer form
        separate groups with different base learning rates.

        The return value of this method can be used in the constructor of
        optimizers, for example:

        .. code-block:: python

            model = XLNetEncoder(...)
            param_groups = model.param_groups(lr=2e-5, lr_layer_scale=0.8)
            optim = torch.optim.Adam(param_groups)

        Args:
            lr (float): The learning rate. Can be omitted if
                :attr:`lr_layer_scale` is 1.0.
            lr_layer_scale (float): Per-layer scaling factor for the learning
                rate. The `i`-th layer's learning rate is scaled by
                `lr_layer_scale ^ (num_layers - i - 1)`.
            decay_base_params (bool): If `True`, also decay non-layer
                parameters (e.g. embeddings), scaling them by
                `lr_layer_scale ^ num_layers`. If `False`, these parameters
                are not scaled.

        Returns:
            The parameter groups, used as the first argument for optimizers.
        """

        if lr_layer_scale != 1.0:
            if lr is None:
                raise ValueError(
                    "lr must be specified when lr_layer_scale is not 1.0")

            num_layers = self._hparams.num_layers
            base_group = {
                "params": params_except_in(
                    self, ["attn_layers", "ff_layers"]),
                "lr": lr * (lr_layer_scale**num_layers
                            if decay_base_params else 1.0),
            }
            param_groups = [base_group]
            for idx in range(num_layers):
                decay_rate = lr_layer_scale**(num_layers - idx - 1)
                param_group = {
                    "params": [
                        *self.attn_layers[idx].parameters(),
                        *self.ff_layers[idx].parameters(),
                    ],
                    "lr": lr * decay_rate,
                }
                param_groups.append(param_group)
            return param_groups
        return self.parameters()
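
Usage note: the groups built above follow the standard PyTorch pattern of
passing a list of parameter-group dicts to an optimizer. Below is a minimal,
runnable sketch of the same layer-wise decay, using a toy nn.ModuleList in
place of XLNetEncoder; the model, sizes, and hyperparameter values are
illustrative assumptions, not library code.

    import torch
    import torch.nn as nn

    num_layers, lr, lr_layer_scale = 4, 2e-5, 0.8

    # Toy stand-ins: "layers" plays the role of attn_layers/ff_layers,
    # "embed" the role of the non-layer base parameters.
    layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(num_layers))
    embed = nn.Embedding(100, 8)

    # Base (non-layer) params keep the full lr, i.e. decay_base_params=False.
    param_groups = [{"params": embed.parameters(), "lr": lr}]
    for idx in range(num_layers):
        decay_rate = lr_layer_scale ** (num_layers - idx - 1)
        param_groups.append({"params": layers[idx].parameters(),
                             "lr": lr * decay_rate})

    optim = torch.optim.Adam(param_groups)
    for i, group in enumerate(optim.param_groups):
        print(f"group {i}: lr = {group['lr']:.2e}")
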
Example #2
    def param_groups(self,
                     lr: Optional[float] = None,
                     lr_layer_scale: float = 1.0,
                     decay_base_params: bool = False):
        r"""Create parameter groups for optimizers. When
        :attr:`lr_layer_decay_rate` is not 1.0, parameters from each layer form
        separate groups with different base learning rates.

        The return value of this method can be used in the constructor of
        optimizers, for example:

        .. code-block:: python

            model = XLNetClassifier(...)
            param_groups = model.param_groups(lr=2e-5, lr_layer_scale=0.8)
            optim = torch.optim.Adam(param_groups)

        Args:
            lr (float): The learning rate. Can be omitted if
                :attr:`lr_layer_scale` is 1.0.
            lr_layer_scale (float): Per-layer scaling factor for the learning
                rate. The `i`-th layer's learning rate is scaled by
                `lr_layer_scale ^ (num_layers - i - 1)`.
            decay_base_params (bool): If `True`, also decay non-layer
                parameters (e.g. embeddings), scaling them by
                `lr_layer_scale ^ num_layers`. If `False`, these parameters
                are not scaled.

        Returns:
            The parameter groups, used as the first argument for optimizers.
        """

        # TODO: Same logic in XLNetRegressor. Reduce code redundancy.

        if lr_layer_scale != 1.0:
            if lr is None:
                raise ValueError(
                    "lr must be specified when lr_layer_scale is not 1.0")

            fine_tune_group = {
                "params": params_except_in(self, ["_encoder"]),
                "lr": lr
            }
            param_groups = [fine_tune_group]
            encoder_groups = self._encoder.param_groups(lr, lr_layer_scale,
                                                        decay_base_params)
            param_groups.extend(encoder_groups)
            return param_groups
        return self.parameters()
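
The classifier variant composes two pieces: a full-lr group for its own
non-encoder parameters, followed by the encoder's per-layer groups. A toy
sketch of that composition (a plain linear "head" and an nn.ModuleList stand
in for the classifier and encoder; all names and values are illustrative
assumptions):

    import torch
    import torch.nn as nn

    num_layers, lr, lr_layer_scale = 4, 2e-5, 0.8

    # Toy stand-ins: "encoder_layers" for the encoder, "head" for the
    # classifier-only parameters, which keep the full, undecayed lr.
    encoder_layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(num_layers))
    head = nn.Linear(8, 2)

    param_groups = [{"params": head.parameters(), "lr": lr}]
    for idx in range(num_layers):
        param_groups.append({
            "params": encoder_layers[idx].parameters(),
            "lr": lr * lr_layer_scale ** (num_layers - idx - 1),
        })

    optim = torch.optim.Adam(param_groups)
    print([g["lr"] for g in optim.param_groups])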