    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):

        super().__init__(hparams=hparams)

        # Create the underlying encoder
        encoder_hparams = dict_fetch(hparams, XLNetEncoder.default_hparams())

        self._encoder = XLNetEncoder(
            pretrained_model_name=pretrained_model_name,
            cache_dir=cache_dir,
            hparams=encoder_hparams)

        # TODO: The logic here is very similar to that in XLNetClassifier.
        #  We need to reduce the code redundancy.
        if self._hparams.use_projection:
            if self._hparams.regr_strategy == 'all_time':
                self.projection = nn.Linear(
                    self._encoder.output_size * self._hparams.max_seq_length,
                    self._encoder.output_size * self._hparams.max_seq_length)
            else:
                self.projection = nn.Linear(self._encoder.output_size,
                                            self._encoder.output_size)
        self.dropout = nn.Dropout(self._hparams.dropout)

        logit_kwargs = self._hparams.logit_layer_kwargs
        if logit_kwargs is None:
            logit_kwargs = {}
        elif not isinstance(logit_kwargs, HParams):
            raise ValueError("hparams['logit_layer_kwargs'] "
                             "must be a dict.")
        else:
            logit_kwargs = logit_kwargs.todict()

        if self._hparams.regr_strategy == 'all_time':
            self.hidden_to_logits = nn.Linear(
                self._encoder.output_size * self._hparams.max_seq_length,
                1, **logit_kwargs)
        else:
            self.hidden_to_logits = nn.Linear(
                self._encoder.output_size, 1, **logit_kwargs)

        if self._hparams.initializer:
            initialize = get_initializer(self._hparams.initializer)
            assert initialize is not None
            if self._hparams.use_projection:
                initialize(self.projection.weight)
                initialize(self.projection.bias)
            initialize(self.hidden_to_logits.weight)
            if self.hidden_to_logits.bias is not None:
                initialize(self.hidden_to_logits.bias)
        else:
            if self._hparams.use_projection:
                self.projection.apply(init_weights)
            self.hidden_to_logits.apply(init_weights)

    @staticmethod
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in XLNetEncoder
                ...
                # (2) Additional hyperparameters
                "regr_strategy": "cls_time",
                "use_projection": True,
                "logit_layer_kwargs": None,
                "name": "xlnet_regressor",
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.torch.modules.XLNetEncoder`.
           See the :meth:`~texar.torch.modules.XLNetEncoder.default_hparams`.
           An instance of XLNetEncoder is created for feature extraction.

        2. Additional hyperparameters:

            `"regr_strategy"`: str
                The regression strategy, one of:

                - **cls_time**: Sequence-level regression based on the
                  output of the last time step (which is the `CLS` token;
                  XLNet places `CLS` at the end of the sequence).
                  Each sequence has a prediction.
                - **all_time**: Sequence-level regression based on
                  the output of all time steps. Each sequence has a prediction.
                - **time_wise**: Step-wise regression, i.e., make
                  regression for each time step based on its output.

            `"logit_layer_kwargs"`: dict
                Keyword arguments for the logit :torch_nn:`Linear` layer
                constructor. Ignored if no extra logit layer is appended.

            `"use_projection"`: bool
                If `True`, an additional :torch_nn:`Linear` layer is added after
                the summary step.

            `"name"`: str
                Name of the regressor.
        """

        hparams = XLNetEncoder.default_hparams()
        hparams.update({
            "regr_strategy": "cls_time",
            "use_projection": True,
            "logit_layer_kwargs": None,
            "name": "xlnet_regressor",
        })
        return hparams
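
For reference, a minimal usage sketch (not part of the original example, and assuming texar-torch's public API): `token_ids` below is a dummy tensor, and the model name is one of the standard pretrained checkpoints.

import torch
from texar.torch.modules import XLNetRegressor

# Hypothetical setup with the default "cls_time" strategy; with
# "all_time", the projection and logit layers instead operate on the
# flattened outputs of all max_seq_length time steps (see __init__ above).
regressor = XLNetRegressor(
    pretrained_model_name="xlnet-base-cased",
    hparams={"regr_strategy": "cls_time"})

token_ids = torch.randint(0, 32000, (2, 16))  # dummy (batch_size, seq_len)
preds = regressor(token_ids)  # one scalar prediction per sequence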
Example #3
    @staticmethod
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in XLNetEncoder
                ...
                # (2) Additional hyperparameters
                "clas_strategy": "cls_time",
                "use_projection": True,
                "num_classes": 2,
                "name": "xlnet_classifier",
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.torch.modules.XLNetEncoder`.
           See the :meth:`~texar.torch.modules.XLNetEncoder.default_hparams`.
           An instance of XLNetEncoder is created for feature extraction.

        2. Additional hyperparameters:

            `"clas_strategy"`: str
                The classification strategy, one of:

                - **cls_time**: Sequence-level classification based on the
                  output of the last time step (which is the `CLS` token).
                  Each sequence has a class.
                - **all_time**: Sequence-level classification based on
                  the output of all time steps. Each sequence has a class.
                - **time_wise**: Step-wise classification, i.e., make
                  classification for each time step based on its output.

            `"use_projection"`: bool
                If `True`, an additional `Linear` layer is added after the
                summary step.

            `"num_classes"`: int
                Number of classes:

                - If **> 0**, an additional :torch_nn:`Linear`
                  layer is appended to the encoder to compute the logits over
                  classes.
                - If **<= 0**, no dense layer is appended. The number of
                  classes is assumed to be the final dense layer size of the
                  encoder.

            `"name"`: str
                Name of the classifier.
        """

        hparams = XLNetEncoder.default_hparams()
        hparams.update({
            "clas_strategy": "cls_time",
            "use_projection": True,
            "num_classes": 2,
            "logit_layer_kwargs": None,
            "name": "xlnet_classifier",
        })
        return hparams
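
Likewise, a minimal usage sketch for the classifier (again an assumption based on texar-torch's public API, not taken from the original page); it returns both logits and predicted class ids:

import torch
from texar.torch.modules import XLNetClassifier

classifier = XLNetClassifier(
    pretrained_model_name="xlnet-base-cased",
    hparams={"clas_strategy": "cls_time", "num_classes": 2})

token_ids = torch.randint(0, 32000, (2, 16))  # dummy (batch_size, seq_len)
logits, preds = classifier(token_ids)
# logits: (batch_size, num_classes); preds: (batch_size,) class indices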