def test_decode_infer(self):
        r"""Runs the decoder in eval mode with several inference helpers.

        Two helpers come from the built-in ``'infer_greedy'`` and
        ``'infer_sample'`` decoding strategies; three more are built by class
        name through :func:`get_helper`. Each helper is used to decode, and
        the resulting sequence lengths must not exceed the decoding cap.
        """
        decoder = BasicRNNDecoder(input_size=self._emb_dim,
                                  vocab_size=self._vocab_size,
                                  hparams=self._hparams)
        decoder.eval()

        end_token = self._vocab_size - 1
        start_tokens = torch.tensor(
            [self._vocab_size - 2] * self._batch_size)
        max_length = 100

        # Helpers obtained through the `decoding_strategy` shortcut.
        helpers = [
            decoder.create_helper(decoding_strategy=strategy_name,
                                  embedding=self._embedding,
                                  start_tokens=start_tokens,
                                  end_token=end_token)
            for strategy_name in ('infer_greedy', 'infer_sample')
        ]

        # Helpers looked up by class name; the same extra keyword arguments
        # are handed to every class.
        extra_kwargs = dict(top_k=self._vocab_size // 2, tau=2.0,
                            straight_through=True)
        helpers.extend(
            get_helper(class_name, embedding=self._embedding,
                       start_tokens=start_tokens, end_token=end_token,
                       **extra_kwargs)
            for class_name in ('TopKSampleEmbeddingHelper',
                               'SoftmaxEmbeddingHelper',
                               'GumbelSoftmaxEmbeddingHelper'))

        for current_helper in helpers:
            outputs, final_state, sequence_lengths = decoder(
                helper=current_helper, max_decoding_length=max_length)
            self.assertLessEqual(max(sequence_lengths), max_length)
            self._test_outputs(decoder, outputs, final_state,
                               sequence_lengths, test_mode=True,
                               helper=current_helper)
# Example #2
    def create_helper(self,
                      *,
                      decoding_strategy: Optional[str] = None,
                      embedding: Optional[Embedding] = None,
                      start_tokens: Optional[torch.LongTensor] = None,
                      end_token: Optional[int] = None,
                      softmax_temperature: Optional[float] = None,
                      infer_mode: Optional[bool] = None,
                      **kwargs) -> Helper:
        r"""Create a helper instance for the decoder. This is a shared interface
        for both :class:`~texar.modules.BasicRNNDecoder` and
        :class:`~texar.modules.AttentionRNNDecoder`.

        The function provides **3 ways** to specify the
        decoding method, with varying flexibility:

        1. The :attr:`decoding_strategy` argument: A string taking value of:

            - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
              feeding `ground truth` to decode the next step), and each sample
              is obtained by taking the `argmax` of the output logits.
              Arguments :attr:`(inputs, sequence_length)`
              are required for this strategy, and argument :attr:`embedding`
              is optional.
            - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
              the `generated` sample to decode the next step), and each sample
              is obtained by taking the `argmax` of the output logits.
              Arguments :attr:`(embedding, start_tokens, end_token)` are
              required for this strategy, and argument
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: decoding in inference fashion, and each
              sample is obtained by `random sampling` from the RNN output
              distribution. Arguments
              :attr:`(embedding, start_tokens, end_token)` are
              required for this strategy, and argument
              :attr:`max_decoding_length` is optional.

          This argument is used only when argument :attr:`helper` is `None`.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding
                outputs_1, _, _ = decoder(
                    decoding_strategy='train_greedy',
                    inputs=embedder(data_batch['text_ids']),
                    sequence_length=data_batch['length'] - 1)

                # Random sample decoding. Gets 100 sequence samples
                outputs_2, _, sequence_length = decoder(
                    decoding_strategy='infer_sample',
                    start_tokens=[data.vocab.bos_token_id] * 100,
                    end_token=data.vocab.eos_token_id,
                    embedding=embedder,
                    max_decoding_length=60)

        2. The :attr:`helper` argument: An instance of subclass of
           :class:`~texar.modules.decoders.rnn_decoder_helpers.Helper`. This
           provides a superset of the decoding strategies above, for example:

            - :class:`~texar.modules.TrainingHelper` corresponding to the
              "train_greedy" strategy.
            - :class:`~texar.modules.ScheduledEmbeddingTrainingHelper` and
              :class:`~texar.modules.ScheduledOutputTrainingHelper` for
              scheduled sampling.
            - :class:`~texar.modules.SoftmaxEmbeddingHelper` and
              :class:`~texar.modules.GumbelSoftmaxEmbeddingHelper` for
              soft decoding and gradient backpropagation.

          This gives the maximum flexibility in configuring the decoding
          strategy.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding, same as above with
                # `decoding_strategy='train_greedy'`
                helper_1 = TrainingHelper(
                    inputs=embedder(data_batch['text_ids']),
                    sequence_length=data_batch['length'] - 1)
                outputs_1, _, _ = decoder(helper=helper_1)

                # Gumbel-softmax decoding
                helper_2 = GumbelSoftmaxEmbeddingHelper(
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id] * 100,
                    end_token=data.vocab.eos_token_id,
                    tau=0.1)
                outputs_2, _, sequence_length = decoder(
                    max_decoding_length=60, helper=helper_2)

        3. ``hparams["helper_train"]`` and ``hparams["helper_infer"]``:
           Specifying the helper through hyperparameters. Train and infer
           strategy is toggled based on :attr:`mode`. Appropriate arguments
           (e.g., :attr:`inputs`, :attr:`start_tokens`, etc) are selected to
           construct the helper. Additional arguments for helper constructor
           can be provided either through :attr:`**kwargs`, or through
           ``hparams["helper_train/infer"]["kwargs"]``.

           Precedence of constructor arguments (highest first):
           :attr:`**kwargs`, then the explicit arguments of this method that
           are not ``None``, then ``hparams["helper_train/infer"]["kwargs"]``.
           An explicit argument left as ``None`` does not override a value
           configured in the hparams.

           This method is used only when both :attr:`decoding_strategy` and
           :attr:`helper` are ``None``.

           Example:

             .. code-block:: python

                h = {
                    "helper_infer": {
                        "type": "GumbelSoftmaxEmbeddingHelper",
                        "kwargs": { "tau": 0.1 }
                    }
                }
                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size, hparams=h)

                # Gumbel-softmax decoding
                decoder.eval()  # disable dropout
                output, _, _ = decoder(
                    decoding_strategy=None, # Set to None explicitly
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id] * 100,
                    end_token=data.vocab.eos_token_id,
                    max_decoding_length=60)

        Args:
            decoding_strategy (str): A string specifying the decoding
                strategy. Different arguments are required based on the
                strategy.
                Ignored if :attr:`helper` is given.
            embedding (optional): A callable that returns embedding vectors
                of `inputs` (e.g., an instance of subclass of
                :class:`~texar.modules.EmbedderBase`), or the `params`
                argument of
                :tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>`.
                If provided, `inputs` (if used) will be passed to
                `embedding` to fetch the embedding vectors of the inputs.
                Required when :attr:`decoding_strategy` is ``"infer_greedy"``
                or ``"infer_sample"``; optional when
                ``decoding_strategy="train_greedy"``.
            start_tokens (optional): A :tensor:`LongTensor` of shape
                ``[batch_size]``, the start tokens.
                Used when :attr:`decoding_strategy` is ``"infer_greedy"`` or
                ``"infer_sample"``, or when `hparams`-configured
                helper is used.
                When used with the Texar data module, to get ``batch_size``
                samples where ``batch_size`` is changing according to the data
                module, this can be set as
                ``start_tokens=torch.full_like(batch['length'], bos_token_id)``.
            end_token (optional): An integer or 0D :tensor:`LongTensor`, the
                token that marks the end of decoding.
                Used when :attr:`decoding_strategy` is ``"infer_greedy"`` or
                ``"infer_sample"``, or when `hparams`-configured helper is used.
            softmax_temperature (float, optional): Value to divide the logits
                by before computing the softmax. Larger values (above 1.0)
                result in more random samples. Must be > 0. If ``None``, 1.0 is
                used. Used when ``decoding_strategy="infer_sample"``; also
                forwarded to an `hparams`-configured helper.
            infer_mode (optional): If not ``None``, overrides mode given by
                :attr:`self.training`. Only consulted when
                :attr:`decoding_strategy` is ``None``.
            **kwargs: Other keyword arguments for constructing helpers
                defined by ``hparams["helper_train"]`` or
                ``hparams["helper_infer"]``. These take precedence over all
                other sources.

        Returns:
            The constructed helper instance.

        Raises:
            ValueError: If :attr:`decoding_strategy` is not one of the
                supported strings, or if an inference strategy is requested
                without :attr:`embedding`, :attr:`start_tokens`, and
                :attr:`end_token`.
        """
        if decoding_strategy is not None:
            if decoding_strategy == 'train_greedy':
                helper: Helper = decoder_helpers.TrainingHelper(
                    embedding, self._input_time_major)
            elif decoding_strategy in ['infer_greedy', 'infer_sample']:
                if (embedding is None or start_tokens is None
                        or end_token is None):
                    raise ValueError(
                        f"When using '{decoding_strategy}' decoding strategy, "
                        f"'embedding', 'start_tokens', and 'end_token' must "
                        f"not be `None`.")
                if decoding_strategy == 'infer_greedy':
                    helper = decoder_helpers.GreedyEmbeddingHelper(
                        embedding, start_tokens, end_token)
                else:
                    helper = decoder_helpers.SampleEmbeddingHelper(
                        embedding, start_tokens, end_token,
                        softmax_temperature)
            else:
                raise ValueError(
                    f"Unknown decoding strategy: {decoding_strategy}")
        else:
            # `infer_mode`, when given, overrides the module's training flag.
            is_training = (not infer_mode
                           if infer_mode is not None else self.training)
            if is_training:
                kwargs_ = copy.copy(self._hparams.helper_train.kwargs.todict())
                helper_type = self._hparams.helper_train.type
            else:
                kwargs_ = copy.copy(self._hparams.helper_infer.kwargs.todict())
                helper_type = self._hparams.helper_infer.type
            explicit_args = {
                'time_major': self._input_time_major,
                'embedding': embedding,
                'start_tokens': start_tokens,
                'end_token': end_token,
                'softmax_temperature': softmax_temperature,
            }
            # An explicit argument overrides the hparams-configured value only
            # when it was actually provided; a `None` placeholder must not
            # clobber e.g. a `softmax_temperature` set in the hparams kwargs.
            # Keys absent from the hparams still receive the value (possibly
            # `None`), exactly as before.
            for key, value in explicit_args.items():
                if value is not None or key not in kwargs_:
                    kwargs_[key] = value
            # Caller-supplied **kwargs have the final say.
            kwargs_.update(kwargs)
            helper = decoder_helpers.get_helper(helper_type, **kwargs_)
        return helper