Example #1
    def test_trainable_variables(self):
        r"""Tests the functionality of automatically collecting trainable
        variables.
        """
        # case 1: xlnet base
        encoder = XLNetEncoder()
        self.assertEqual(len(encoder.trainable_variables), 182)
        _, _ = encoder(self.inputs)

        # case 2: xlnet large
        hparams = {
            "pretrained_model_name": "xlnet-large-cased",
        }
        encoder = XLNetEncoder(hparams=hparams)
        self.assertEqual(len(encoder.trainable_variables), 362)
        _, _ = encoder(self.inputs)

        # case 3: self-designed xlnet
        hparams = {
            "num_layers": 6,
            "pretrained_model_name": None,
        }
        encoder = XLNetEncoder(hparams=hparams)
        self.assertEqual(len(encoder.trainable_variables), 92)
        _, _ = encoder(self.inputs)
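
The counts asserted above are numbers of trainable parameter tensors, not scalar parameter counts. A minimal, library-agnostic sketch of the same idea (`count_trainable` is a hypothetical helper, not texar API):

import torch.nn as nn

def count_trainable(module: nn.Module) -> int:
    # Count parameter tensors that require gradients, which is what the
    # `trainable_variables` assertions above check the length of.
    return sum(1 for p in module.parameters() if p.requires_grad)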
Example #2
    def test_hparams(self):
        r"""Tests the priority of the encoder arch parameter.
        """
        # case 1: set "pretrained_model_name" by constructor argument
        hparams = {
            "pretrained_model_name": "xlnet-large-cased",
        }
        encoder = XLNetEncoder(pretrained_model_name="xlnet-base-cased",
                               hparams=hparams)
        self.assertEqual(encoder.hparams.num_layers, 12)
        _, _ = encoder(self.inputs)

        # case 2: set "pretrained_model_name" by hparams
        hparams = {
            "pretrained_model_name": "xlnet-large-cased",
            "num_layers": 6,
        }
        encoder = XLNetEncoder(hparams=hparams)
        self.assertEqual(encoder.hparams.num_layers, 24)
        _, _ = encoder(self.inputs)

        # case 3: set to None in both hparams and constructor argument
        hparams = {
            "pretrained_model_name": None,
            "num_layers": 6,
        }
        encoder = XLNetEncoder(hparams=hparams)
        self.assertEqual(encoder.hparams.num_layers, 6)
        _, _ = encoder(self.inputs)

        # case 4: using default hparams
        encoder = XLNetEncoder()
        self.assertEqual(encoder.hparams.num_layers, 12)
        _, _ = encoder(self.inputs)
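
The cases above exercise a simple priority rule: the constructor argument overrides the value in hparams, and when both are `None` the remaining hparams define a self-designed architecture. A minimal sketch of that resolution order (`resolve_model_name` is a hypothetical illustration, not texar API):

from typing import Optional

def resolve_model_name(ctor_arg: Optional[str],
                       hparams_name: Optional[str]) -> Optional[str]:
    # The constructor argument wins over hparams; if both are None, no
    # pre-trained checkpoint is loaded and hparams such as num_layers
    # define the architecture directly.
    return ctor_arg if ctor_arg is not None else hparams_name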
Example #3
    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):

        super().__init__(hparams=hparams)

        # Create the underlying encoder
        encoder_hparams = dict_fetch(hparams, XLNetEncoder.default_hparams())

        self._encoder = XLNetEncoder(
            pretrained_model_name=pretrained_model_name,
            cache_dir=cache_dir,
            hparams=encoder_hparams)

        # TODO: The logic here is very similar to that in XLNetClassifier.
        #  We need to reduce the code redundancy.
        if self._hparams.use_projection:
            if self._hparams.regr_strategy == 'all_time':
                self.projection = nn.Linear(
                    self._encoder.output_size * self._hparams.max_seq_length,
                    self._encoder.output_size * self._hparams.max_seq_length)
            else:
                self.projection = nn.Linear(self._encoder.output_size,
                                            self._encoder.output_size)
        self.dropout = nn.Dropout(self._hparams.dropout)

        logit_kwargs = self._hparams.logit_layer_kwargs
        if logit_kwargs is None:
            logit_kwargs = {}
        elif not isinstance(logit_kwargs, HParams):
            raise ValueError("hparams['logit_layer_kwargs'] "
                             "must be a dict.")
        else:
            logit_kwargs = logit_kwargs.todict()

        if self._hparams.regr_strategy == 'all_time':
            self.hidden_to_logits = nn.Linear(
                self._encoder.output_size * self._hparams.max_seq_length,
                1, **logit_kwargs)
        else:
            self.hidden_to_logits = nn.Linear(
                self._encoder.output_size, 1, **logit_kwargs)

        if self._hparams.initializer:
            initialize = get_initializer(self._hparams.initializer)
            assert initialize is not None
            if self._hparams.use_projection:
                initialize(self.projection.weight)
                initialize(self.projection.bias)
            initialize(self.hidden_to_logits.weight)
            if self.hidden_to_logits.bias is not None:
                initialize(self.hidden_to_logits.bias)
        else:
            if self._hparams.use_projection:
                self.projection.apply(init_weights)
            self.hidden_to_logits.apply(init_weights)
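
The projection and logit-layer sizes above depend only on the regression strategy. A small sketch of that dimension rule (`logit_input_dim` is a hypothetical helper written for illustration):

def logit_input_dim(strategy: str, output_size: int,
                    max_seq_length: int) -> int:
    # 'all_time' flattens all time steps into one feature vector, so the
    # linear layers see output_size * max_seq_length features; the other
    # strategies operate on a single time step's features.
    if strategy == 'all_time':
        return output_size * max_seq_length
    return output_size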
Example #4
    def test_model_loading(self):
        r"""Tests model loading functionality."""
        # case 1
        encoder = XLNetEncoder(pretrained_model_name="xlnet-base-cased")
        _, _ = encoder(self.inputs)

        # case 2
        encoder = XLNetEncoder(pretrained_model_name="xlnet-large-cased")
        _, _ = encoder(self.inputs)
Example #5
    def test_encode(self):
        r"""Tests encoding.
        """
        # case 1: xlnet base
        hparams = {
            "pretrained_model_name": None,
        }
        encoder = XLNetEncoder(hparams=hparams)

        inputs = torch.randint(32000, (self.batch_size, self.max_length))
        outputs, new_memory = encoder(inputs)

        self.assertEqual(
            outputs.shape,
            torch.Size([self.batch_size, self.max_length,
                        encoder.output_size]))
        self.assertEqual(new_memory, None)

        # case 2: self-designed xlnet
        hparams = {
            'pretrained_model_name': None,
            'untie_r': True,
            'num_layers': 6,
            'mem_len': 0,
            'reuse_len': 0,
            'num_heads': 8,
            'hidden_dim': 32,
            'head_dim': 64,
            'dropout': 0.1,
            'attention_dropout': 0.1,
            'use_segments': True,
            'ffn_inner_dim': 256,
            'activation': 'gelu',
            'vocab_size': 32000,
            'max_seq_length': 128,
            'initializer': None,
            'name': "xlnet_encoder",
        }
        encoder = XLNetEncoder(hparams=hparams)
        outputs, new_memory = encoder(inputs)

        self.assertEqual(
            outputs.shape,
            torch.Size([self.batch_size, self.max_length,
                        encoder.output_size]))
        self.assertEqual(new_memory, None)
Example #6
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in XLNetEncoder
                ...
                # (2) Additional hyperparameters
                "regr_strategy": "cls_time",
                "use_projection": True,
                "logit_layer_kwargs": None,
                "name": "xlnet_regressor",
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.torch.modules.XLNetEncoder`.
           See the :meth:`~texar.torch.modules.XLNetEncoder.default_hparams`.
           An instance of XLNetEncoder is created for feature extraction.

        2. Additional hyperparameters:

            `"regr_strategy"`: str
                The regression strategy, one of:

                - **cls_time**: Sequence-level regression based on the
                  output of the last time step (which is the `CLS` token).
                  Each sequence has a prediction.
                - **all_time**: Sequence-level regression based on
                  the output of all time steps. Each sequence has a prediction.
                - **time_wise**: Step-wise regression, i.e., make
                  regression for each time step based on its output.

            `"logit_layer_kwargs"`: dict
                Keyword arguments for the logit :torch_nn:`Linear` layer
                constructor. Ignored if no extra logit layer is appended.

            `"use_projection"`: bool
                If `True`, an additional :torch_nn:`Linear` layer is added after
                the summary step.

            `"name"`: str
                Name of the regressor.
        """

        hparams = XLNetEncoder.default_hparams()
        hparams.update({
            "regr_strategy": "cls_time",
            "use_projection": True,
            "logit_layer_kwargs": None,
            "name": "xlnet_regressor",
        })
        return hparams
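
A brief usage sketch of the defaults above (assumes texar-pytorch is installed and the ``xlnet-base-cased`` checkpoint can be downloaded); only the additional hyperparameters are set, the rest fall back to the encoder defaults:

from texar.torch.modules import XLNetRegressor

regressor = XLNetRegressor(
    pretrained_model_name="xlnet-base-cased",
    hparams={"regr_strategy": "cls_time", "use_projection": True})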
Example #7
    def test_soft_ids(self):
        r"""Tests soft ids.
        """
        hparams = {
            "pretrained_model_name": None,
        }
        encoder = XLNetEncoder(hparams=hparams)

        inputs = torch.rand(self.batch_size, self.max_length, 32000)
        outputs, new_memory = encoder(inputs)

        self.assertEqual(
            outputs.shape,
            torch.Size([self.batch_size, self.max_length,
                        encoder.output_size]))
        self.assertEqual(new_memory, None)
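
Soft ids are per-position weights over the vocabulary that are used to mix embedding vectors. The test feeds unnormalized `torch.rand` values; a sketch of normalized soft ids (batch size and length are illustrative):

import torch
import torch.nn.functional as F

batch_size, max_length, vocab_size = 2, 16, 32000
# Each position holds a probability distribution over the vocabulary.
soft_ids = F.softmax(torch.randn(batch_size, max_length, vocab_size), dim=-1)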
Example #8
class XLNetClassifier(ClassifierBase, PretrainedXLNetMixin):
    r"""Classifier based on XLNet modules. Please see
    :class:`~texar.torch.modules.PretrainedXLNetMixin` for a brief description
    of XLNet.

    Arguments are the same as in
    :class:`~texar.torch.modules.XLNetEncoder`.

    Args:
        pretrained_model_name (optional): a `str`, the name
            of pre-trained model (e.g., ``xlnet-base-cased``). Please refer to
            :class:`~texar.torch.modules.PretrainedXLNetMixin` for
            all supported models.
            If `None`, the model name in :attr:`hparams` is used.
        cache_dir (optional): the path to a folder in which the
            pre-trained models will be cached. If `None` (default),
            a default directory (``texar_data`` folder under user's home
            directory) will be used.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure
            and default values.
    """

    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):

        super().__init__(hparams=hparams)

        # Create the underlying encoder
        encoder_hparams = dict_fetch(hparams, XLNetEncoder.default_hparams())

        self._encoder = XLNetEncoder(
            pretrained_model_name=pretrained_model_name,
            cache_dir=cache_dir,
            hparams=encoder_hparams)

        # TODO: The logic here is very similar to that in XLNetRegressor.
        #  We need to reduce the code redundancy.
        if self._hparams.use_projection:
            if self._hparams.clas_strategy == 'all_time':
                self.projection = nn.Linear(
                    self._encoder.output_size * self._hparams.max_seq_length,
                    self._encoder.output_size * self._hparams.max_seq_length)
            else:
                self.projection = nn.Linear(self._encoder.output_size,
                                            self._encoder.output_size)
        self.dropout = nn.Dropout(self._hparams.dropout)

        # Create an additional classification layer if needed
        self.num_classes = self._hparams.num_classes
        if self.num_classes <= 0:
            self.hidden_to_logits = None
        else:
            logit_kwargs = self._hparams.logit_layer_kwargs
            if logit_kwargs is None:
                logit_kwargs = {}
            elif not isinstance(logit_kwargs, HParams):
                raise ValueError("hparams['logit_layer_kwargs'] "
                                 "must be a dict.")
            else:
                logit_kwargs = logit_kwargs.todict()

            if self._hparams.clas_strategy == 'all_time':
                self.hidden_to_logits = nn.Linear(
                    self._encoder.output_size * self._hparams.max_seq_length,
                    self.num_classes,
                    **logit_kwargs)
            else:
                self.hidden_to_logits = nn.Linear(
                    self._encoder.output_size, self.num_classes,
                    **logit_kwargs)

        if self._hparams.initializer:
            initialize = get_initializer(self._hparams.initializer)
            assert initialize is not None
            if self._hparams.use_projection:
                initialize(self.projection.weight)
                initialize(self.projection.bias)
            if self.hidden_to_logits:
                initialize(self.hidden_to_logits.weight)
                if self.hidden_to_logits.bias is not None:
                    initialize(self.hidden_to_logits.bias)
        else:
            if self._hparams.use_projection:
                self.projection.apply(init_weights)
            if self.hidden_to_logits:
                self.hidden_to_logits.apply(init_weights)

        self.is_binary = ((self.num_classes == 1) or
                          (self.num_classes <= 0 and
                           self._hparams.hidden_dim == 1))

    @staticmethod
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in XLNetEncoder
                ...
                # (2) Additional hyperparameters
                "clas_strategy": "cls_time",
                "use_projection": True,
                "num_classes": 2,
                "name": "xlnet_classifier",
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.torch.modules.XLNetEncoder`.
           See the :meth:`~texar.torch.modules.XLNetEncoder.default_hparams`.
           An instance of XLNetEncoder is created for feature extraction.

        2. Additional hyperparameters:

            `"clas_strategy"`: str
                The classification strategy, one of:

                - **cls_time**: Sequence-level classification based on the
                  output of the last time step (which is the `CLS` token).
                  Each sequence has a class.
                - **all_time**: Sequence-level classification based on
                  the output of all time steps. Each sequence has a class.
                - **time_wise**: Step-wise classification, i.e., make
                  classification for each time step based on its output.

            `"use_projection"`: bool
                If `True`, an additional `Linear` layer is added after the
                summary step.

            `"num_classes"`: int
                Number of classes:

                - If **> 0**, an additional :torch_nn:`Linear`
                  layer is appended to the encoder to compute the logits over
                  classes.
                - If **<= 0**, no dense layer is appended. The number of
                  classes is assumed to be the final dense layer size of the
                  encoder.

            `"name"`: str
                Name of the classifier.
        """

        hparams = XLNetEncoder.default_hparams()
        hparams.update({
            "clas_strategy": "cls_time",
            "use_projection": True,
            "num_classes": 2,
            "logit_layer_kwargs": None,
            "name": "xlnet_classifier",
        })
        return hparams

    def param_groups(self,
                     lr: Optional[float] = None,
                     lr_layer_scale: float = 1.0,
                     decay_base_params: bool = False):
        r"""Create parameter groups for optimizers. When
        :attr:`lr_layer_decay_rate` is not 1.0, parameters from each layer form
        separate groups with different base learning rates.

        The return value of this method can be used in the constructor of
        optimizers, for example:

        .. code-block:: python

            model = XLNetClassifier(...)
            param_groups = model.param_groups(lr=2e-5, lr_layer_scale=0.8)
            optim = torch.optim.Adam(param_groups)

        Args:
            lr (float): The learning rate. Can be omitted if
                :attr:`lr_layer_decay_rate` is 1.0.
            lr_layer_scale (float): Per-layer LR scaling rate. The `i`-th layer
                will be scaled by `lr_layer_scale ^ (num_layers - i - 1)`.
            decay_base_params (bool): If `True`, treat non-layer parameters
                (e.g. embeddings) as if they're in layer 0. If `False`, these
                parameters are not scaled.

        Returns:
            The parameter groups, used as the first argument for optimizers.
        """

        # TODO: Same logic in XLNetRegressor. Reduce code redundancy.

        if lr_layer_scale != 1.0:
            if lr is None:
                raise ValueError(
                    "lr must be specified when lr_layer_decay_rate is not 1.0")

            fine_tune_group = {
                "params": params_except_in(self, ["_encoder"]),
                "lr": lr
            }
            param_groups = [fine_tune_group]
            param_group = self._encoder.param_groups(lr, lr_layer_scale,
                                                     decay_base_params)
            param_groups.extend(param_group)
            return param_groups
        return self.parameters()

    def forward(self,  # type: ignore
                inputs: Union[torch.Tensor, torch.LongTensor],
                segment_ids: Optional[torch.LongTensor] = None,
                input_mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, torch.LongTensor]:
        r"""Feeds the inputs through the network and makes classification.

        Args:
            inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
                containing the ids of tokens in input sequences, or
                a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`,
                containing soft token ids (i.e., weights or probabilities)
                used to mix the embedding vectors.
            segment_ids: Shape `[batch_size, max_time]`.
            input_mask: Float tensor of shape `[batch_size, max_time]`. Note
                that positions with value 1 are masked out.

        Returns:
            A tuple `(logits, preds)`, containing the logits over classes and
            the predictions, respectively.

            - If ``clas_strategy`` is ``cls_time`` or ``all_time``:

                - If ``num_classes`` == 1, ``logits`` and ``preds`` are both
                  of shape ``[batch_size]``.
                - If ``num_classes`` > 1, ``logits`` is of shape
                  ``[batch_size, num_classes]`` and ``preds`` is of shape
                  ``[batch_size]``.

            - If ``clas_strategy`` is ``time_wise``:

                - If ``num_classes`` == 1, ``logits`` and ``preds`` are both
                  of shape ``[batch_size, max_time]``.
                - If ``num_classes`` > 1, ``logits`` is of shape
                  ``[batch_size, max_time, num_classes]`` and ``preds`` is of
                  shape ``[batch_size, max_time]``.
        """
        # output: [batch_size, seq_len, hidden_dim]
        output, _ = self._encoder(inputs=inputs,
                                  segment_ids=segment_ids,
                                  input_mask=input_mask)

        strategy = self._hparams.clas_strategy
        if strategy == 'time_wise':
            summary = output
        elif strategy == 'cls_time':
            summary = output[:, -1]
        elif strategy == 'all_time':
            length_diff = self._hparams.max_seq_length - inputs.shape[1]
            summary_input = F.pad(output, [0, 0, 0, length_diff, 0, 0])
            summary_input_dim = (self._encoder.output_size *
                                 self._hparams.max_seq_length)

            summary = summary_input.contiguous().view(-1, summary_input_dim)
        else:
            raise ValueError(f"Unknown classification strategy: {strategy}.")

        if self._hparams.use_projection:
            summary = torch.tanh(self.projection(summary))

        if self.hidden_to_logits is not None:
            summary = self.dropout(summary)
            logits = self.hidden_to_logits(summary)
        else:
            logits = summary

        # Compute predictions
        if strategy == "time_wise":
            if self.is_binary:
                logits = torch.squeeze(logits, -1)
                preds = (logits > 0).long()
            else:
                preds = torch.argmax(logits, dim=-1)
        else:
            if self.is_binary:
                preds = (logits > 0).long()
                logits = torch.flatten(logits)
            else:
                preds = torch.argmax(logits, dim=-1)
            preds = torch.flatten(preds)

        return logits, preds

    @property
    def output_size(self) -> int:
        r"""The feature size of :meth:`forward` output :attr:`logits`.
        If the :attr:`logits` size is determined only by the input
        (i.e., if ``num_classes`` == 1), the feature size is ``-1``.
        Otherwise it equals the last dimension of :attr:`logits`.
        """
        if self._hparams.num_classes > 1:
            logit_dim = self._hparams.num_classes
        elif self._hparams.num_classes == 1:
            logit_dim = -1
        else:
            logit_dim = self._hparams.hidden_dim
        return logit_dim
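
A short end-to-end sketch of the classifier above (assumes the ``xlnet-base-cased`` weights are available; batch size and sequence length are illustrative):

import torch
from texar.torch.modules import XLNetClassifier

clf = XLNetClassifier(pretrained_model_name="xlnet-base-cased",
                      hparams={"num_classes": 2,
                               "clas_strategy": "cls_time"})
inputs = torch.randint(32000, (4, 16))   # [batch_size, max_time] token ids
logits, preds = clf(inputs)              # shapes [4, 2] and [4]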
Example #9
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in XLNetEncoder
                ...
                # (2) Additional hyperparameters
                "clas_strategy": "cls_time",
                "use_projection": True,
                "num_classes": 2,
                "name": "xlnet_classifier",
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.torch.modules.XLNetEncoder`.
           See the :meth:`~texar.torch.modules.XLNetEncoder.default_hparams`.
           An instance of XLNetEncoder is created for feature extraction.

        2. Additional hyperparameters:

            `"clas_strategy"`: str
                The classification strategy, one of:

                - **cls_time**: Sequence-level classification based on the
                  output of the last time step (which is the `CLS` token).
                  Each sequence has a class.
                - **all_time**: Sequence-level classification based on
                  the output of all time steps. Each sequence has a class.
                - **time_wise**: Step-wise classification, i.e., make
                  classification for each time step based on its output.

            `"use_projection"`: bool
                If `True`, an additional `Linear` layer is added after the
                summary step.

            `"num_classes"`: int
                Number of classes:

                - If **> 0**, an additional :torch_nn:`Linear`
                  layer is appended to the encoder to compute the logits over
                  classes.
                - If **<= 0**, no dense layer is appended. The number of
                  classes is assumed to be the final dense layer size of the
                  encoder.

            `"name"`: str
                Name of the classifier.
        """

        hparams = XLNetEncoder.default_hparams()
        hparams.update({
            "clas_strategy": "cls_time",
            "use_projection": True,
            "num_classes": 2,
            "logit_layer_kwargs": None,
            "name": "xlnet_classifier",
        })
        return hparams
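
A sketch of overriding these defaults at the dictionary level (no model is constructed here):

from texar.torch.modules import XLNetClassifier

hparams = XLNetClassifier.default_hparams()
# Switch to the all_time strategy with a three-way classification head.
hparams.update({"clas_strategy": "all_time", "num_classes": 3})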
Example #10
class XLNetRegressor(RegressorBase):
    r"""Regressor based on XLNet modules.

    Arguments are the same as in
    :class:`~texar.torch.modules.XLNetEncoder`.

    Args:
        pretrained_model_name (optional): a `str`, the name of a
            pre-trained model to load, one of ``xlnet-base-cased`` or
            ``xlnet-large-cased``.
            If `None`, the model name in :attr:`hparams` is used.
        cache_dir (optional): the path to a folder in which the
            pre-trained models will be cached. If `None` (default),
            a default directory will be used.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure
            and default values.
    """
    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):

        super().__init__(hparams=hparams)

        # Create the underlying encoder
        encoder_hparams = dict_fetch(hparams, XLNetEncoder.default_hparams())

        self._encoder = XLNetEncoder(
            pretrained_model_name=pretrained_model_name,
            cache_dir=cache_dir,
            hparams=encoder_hparams)

        # TODO: The logic here is very similar to that in XLNetClassifier.
        #  We need to reduce the code redundancy.
        if self._hparams.use_projection:
            if self._hparams.regr_strategy == 'all_time':
                self.projection = nn.Linear(
                    self._encoder.output_size * self._hparams.max_seq_length,
                    self._encoder.output_size * self._hparams.max_seq_length)
            else:
                self.projection = nn.Linear(self._encoder.output_size,
                                            self._encoder.output_size)
        self.dropout = nn.Dropout(self._hparams.dropout)

        logit_kwargs = self._hparams.logit_layer_kwargs
        if logit_kwargs is None:
            logit_kwargs = {}
        elif not isinstance(logit_kwargs, HParams):
            raise ValueError("hparams['logit_layer_kwargs'] "
                             "must be a dict.")
        else:
            logit_kwargs = logit_kwargs.todict()

        if self._hparams.regr_strategy == 'all_time':
            self.hidden_to_logits = nn.Linear(
                self._encoder.output_size * self._hparams.max_seq_length, 1,
                **logit_kwargs)
        else:
            self.hidden_to_logits = nn.Linear(self._encoder.output_size, 1,
                                              **logit_kwargs)

        if self._hparams.initializer:
            initialize = get_initializer(self._hparams.initializer)
            assert initialize is not None
            if self._hparams.use_projection:
                initialize(self.projection.weight)
                initialize(self.projection.bias)
            initialize(self.hidden_to_logits.weight)
            if self.hidden_to_logits.bias is not None:
                initialize(self.hidden_to_logits.bias)
        else:
            if self._hparams.use_projection:
                self.projection.apply(init_weights)
            self.hidden_to_logits.apply(init_weights)

    @staticmethod
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in XLNetEncoder
                ...
                # (2) Additional hyperparameters
                "regr_strategy": "cls_time",
                "use_projection": True,
                "logit_layer_kwargs": None,
                "name": "xlnet_regressor",
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.torch.modules.XLNetEncoder`.
           See the :meth:`~texar.torch.modules.XLNetEncoder.default_hparams`.
           An instance of XLNetEncoder is created for feature extraction.

        2. Additional hyperparameters:

            `"regr_strategy"`: str
                The regression strategy, one of:

                - **cls_time**: Sequence-level regression based on the
                  output of the last time step (which is the `CLS` token).
                  Each sequence has a prediction.
                - **all_time**: Sequence-level regression based on
                  the output of all time steps. Each sequence has a prediction.
                - **time_wise**: Step-wise regression, i.e., make
                  regression for each time step based on its output.

            `"logit_layer_kwargs"`: dict
                Keyword arguments for the logit :torch_nn:`Linear` layer
                constructor. Ignored if no extra logit layer is appended.

            `"use_projection"`: bool
                If `True`, an additional :torch_nn:`Linear` layer is added after
                the summary step.

            `"name"`: str
                Name of the regressor.
        """

        hparams = XLNetEncoder.default_hparams()
        hparams.update({
            "regr_strategy": "cls_time",
            "use_projection": True,
            "logit_layer_kwargs": None,
            "name": "xlnet_regressor",
        })
        return hparams

    def param_groups(self,
                     lr: Optional[float] = None,
                     lr_layer_scale: float = 1.0,
                     decay_base_params: bool = False):
        r"""Create parameter groups for optimizers. When
        :attr:`lr_layer_decay_rate` is not 1.0, parameters from each layer form
        separate groups with different base learning rates.

        Args:
            lr (float): The learning rate. Can be omitted if
                :attr:`lr_layer_decay_rate` is 1.0.
            lr_layer_scale (float): Per-layer LR scaling rate. The `i`-th layer
                will be scaled by `lr_layer_scale ^ (num_layers - i - 1)`.
            decay_base_params (bool): If `True`, treat non-layer parameters
                (e.g. embeddings) as if they're in layer 0. If `False`, these
                parameters are not scaled.

        Returns:
            The parameter groups, used as the first argument for optimizers.
        """

        # TODO: Same logic in XLNetClassifier. Reduce code redundancy.

        if lr_layer_scale != 1.0:
            if lr is None:
                raise ValueError(
                    "lr must be specified when lr_layer_decay_rate is not 1.0")

            fine_tune_group = {
                "params": params_except_in(self, ["_encoder"]),
                "lr": lr
            }
            param_groups = [fine_tune_group]
            param_group = self._encoder.param_groups(lr, lr_layer_scale,
                                                     decay_base_params)
            param_groups.extend(param_group)
        else:
            param_groups = self.parameters()
        return param_groups

    def forward(
            self,  # type: ignore
            token_ids: torch.LongTensor,
            segment_ids: Optional[torch.LongTensor] = None,
            input_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Feeds the inputs through the network and makes regression.

        Args:
            token_ids: Shape `[batch_size, max_time]`.
            segment_ids: Shape `[batch_size, max_time]`.
            input_mask: Float tensor of shape `[batch_size, max_time]`. Note
                that positions with value 1 are masked out.

        Returns:
            Regression predictions.

            - If ``regr_strategy`` is ``cls_time`` or ``all_time``, predictions
              have shape `[batch_size]`.

            - If ``regr_strategy`` is ``time_wise``, predictions have shape
              `[batch_size, max_time]`.
        """
        # output: [batch_size, seq_len, hidden_dim]
        output, _ = self._encoder(token_ids,
                                  segment_ids=segment_ids,
                                  input_mask=input_mask)

        strategy = self._hparams.regr_strategy
        if strategy == 'time_wise':
            summary = output
        elif strategy == 'cls_time':
            summary = output[:, -1]
        elif strategy == 'all_time':
            length_diff = self._hparams.max_seq_length - token_ids.shape[1]
            summary_input = F.pad(output, [0, 0, 0, length_diff, 0, 0])
            summary_input_dim = (self._encoder.output_size *
                                 self._hparams.max_seq_length)

            summary = summary_input.contiguous().view(-1, summary_input_dim)
        else:
            raise ValueError(
                'Unknown regression strategy: {}'.format(strategy))

        if self._hparams.use_projection:
            summary = torch.tanh(self.projection(summary))

        summary = self.dropout(summary)

        preds = self.hidden_to_logits(summary).squeeze(-1)

        return preds

    @property
    def output_size(self) -> int:
        return 1
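
The regressor's `param_groups` mirrors the classifier's. A usage sketch with layer-wise learning-rate decay (assumes the pretrained weights are available; the optimizer choice is illustrative):

import torch
from texar.torch.modules import XLNetRegressor

model = XLNetRegressor(pretrained_model_name="xlnet-base-cased")
param_groups = model.param_groups(lr=2e-5, lr_layer_scale=0.8)
optim = torch.optim.Adam(param_groups)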
Example #11
    def test_model_loading(self):
        r"""Tests model loading functionality."""
        for pretrained_model_name in XLNetEncoder.available_checkpoints():
            encoder = XLNetEncoder(pretrained_model_name=pretrained_model_name)
            _ = encoder(self.inputs)
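
A quick sketch of listing the checkpoint names the loop above iterates over (`available_checkpoints` is the same method used in the test):

from texar.torch.modules import XLNetEncoder

print(XLNetEncoder.available_checkpoints())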