示例#1
0
    def test_decode_infer_sample(self):
        r"""Runs sampling-based inference decoding on a `GPT2Decoder` and
        checks the type of the returned outputs.
        """
        # The decoder is built with `pretrained_model_name` explicitly set
        # to `None`.
        decoder = GPT2Decoder(hparams={"pretrained_model_name": None})
        decoder.eval()

        # One start token (id 1) per batch element; id 2 is the end token.
        bos_ids = torch.full((self.batch_size, ), 1, dtype=torch.int64)
        eos_id = 2
        sampler = decoder_helpers.SampleEmbeddingHelper(bos_ids, eos_id)

        # The second element of the returned pair is not inspected here.
        outputs, _ = decoder(helper=sampler,
                             max_decoding_length=self.max_length)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
示例#2
0
    def test_decode_infer_sample(self):
        """Runs sampling-based inference decoding on a `TransformerDecoder`
        and checks the type of the returned outputs.
        """
        decoder = TransformerDecoder(vocab_size=self._vocab_size,
                                     output_layer=self._output_layer)
        decoder.eval()

        sampler = decoder_helpers.SampleEmbeddingHelper(
            self._embedding_fn, self._start_tokens, self._end_token)

        # No teacher-forcing inputs and no explicit attention bias are
        # supplied; decoding is driven by the helper alone.
        decode_args = {
            "memory": self._memory,
            "memory_sequence_length": self._memory_sequence_length,
            "memory_attention_bias": None,
            "inputs": None,
            "helper": sampler,
            "max_decoding_length": self._max_decode_len,
        }
        # The second element of the returned pair is not inspected here.
        outputs, _ = decoder(**decode_args)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
    def test_decode_infer_sample(self):
        r"""Runs sampling-based inference decoding on a default
        `GPT2Decoder`, using an explicit embedding function, and checks the
        type of the returned outputs.
        """
        decoder = GPT2Decoder()
        decoder.eval()

        def embedding_fn(x, y):
            # Sum of the decoder's word embedding of `x` and its position
            # embedding of `y`.
            return decoder.word_embedder(x) + decoder.position_embedder(y)

        # 16 start tokens (id 1); id 2 is the end token.
        bos_ids = torch.full((16, ), 1, dtype=torch.int64)
        eos_id = 2
        decode_limit = 16

        sampler = decoder_helpers.SampleEmbeddingHelper(
            embedding_fn, bos_ids, eos_id)

        # The second element of the returned pair is not inspected here.
        outputs, _ = decoder(helper=sampler,
                             max_decoding_length=decode_limit)

        self.assertIsInstance(outputs, TransformerDecoderOutput)
示例#4
0
    def create_helper(self,
                      *,
                      decoding_strategy: Optional[str] = None,
                      embedding: Optional[Embedding] = None,
                      start_tokens: Optional[torch.LongTensor] = None,
                      end_token: Optional[int] = None,
                      softmax_temperature: Optional[float] = None,
                      infer_mode: Optional[bool] = None,
                      **kwargs) -> Helper:
        r"""Build the helper instance that drives decoding. The interface is
        shared by :class:`~texar.modules.BasicRNNDecoder` and
        :class:`~texar.modules.AttentionRNNDecoder`.

        The decoding method can be chosen in **3 ways**, listed from least
        to most flexible. Note that :attr:`helper`, :attr:`inputs`,
        :attr:`sequence_length` and :attr:`max_decoding_length` mentioned
        below are not named parameters of this method; they appear in the
        decoder calls shown in the examples.

        1. Through the :attr:`decoding_strategy` argument, a string that is
           one of:

            - **"train_greedy"**: teacher-forcing decoding (i.e., the
              `ground truth` is fed to decode the next step); each sample is
              the `argmax` of the output logits.
              Arguments :attr:`(inputs, sequence_length)` are required for
              this strategy, and argument :attr:`embedding` is optional.
            - **"infer_greedy"**: inference-style decoding (i.e., the
              `generated` sample is fed to decode the next step); each
              sample is the `argmax` of the output logits.
              Arguments :attr:`(embedding, start_tokens, end_token)` are
              required for this strategy, and argument
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: inference-style decoding where each sample
              is drawn by `random sampling` from the RNN output
              distribution. Arguments
              :attr:`(embedding, start_tokens, end_token)` are required for
              this strategy, and argument :attr:`max_decoding_length` is
              optional.

          This argument only takes effect when argument :attr:`helper` is
          `None`.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding
                outputs_1, _, _ = decoder(
                    decoding_strategy='train_greedy',
                    inputs=embedder(data_batch['text_ids']),
                    sequence_length=data_batch['length'] - 1)

                # Random sample decoding. Gets 100 sequence samples
                outputs_2, _, sequence_length = decoder(
                    decoding_strategy='infer_sample',
                    start_tokens=[data.vocab.bos_token_id] * 100,
                    end_token=data.vocab.eos_token_id,
                    embedding=embedder,
                    max_decoding_length=60)

        2. Through the :attr:`helper` argument: an instance of a subclass of
           :class:`~texar.modules.decoders.rnn_decoder_helpers.Helper`. This
           covers a superset of the strategies above, for example:

            - :class:`~texar.modules.TrainingHelper`, corresponding to the
              "train_greedy" strategy.
            - :class:`~texar.modules.ScheduledEmbeddingTrainingHelper` and
              :class:`~texar.modules.ScheduledOutputTrainingHelper` for
              scheduled sampling.
            - :class:`~texar.modules.SoftmaxEmbeddingHelper` and
              :class:`~texar.modules.GumbelSoftmaxEmbeddingHelper` for
              soft decoding and gradient backpropagation.

          This approach gives the maximal flexibility in configuring the
          decoding strategy.

          Example:

            .. code-block:: python

                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

                # Teacher-forcing decoding, same as above with
                # `decoding_strategy='train_greedy'`
                helper_1 = TrainingHelper(
                    inputs=embedder(data_batch['text_ids']),
                    sequence_length=data_batch['length'] - 1)
                outputs_1, _, _ = decoder(helper=helper_1)

                # Gumbel-softmax decoding
                helper_2 = GumbelSoftmaxEmbeddingHelper(
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id] * 100,
                    end_token=data.vocab.eos_token_id,
                    tau=0.1)
                outputs_2, _, sequence_length = decoder(
                    max_decoding_length=60, helper=helper_2)

        3. Through ``hparams["helper_train"]`` and
           ``hparams["helper_infer"]``: the helper is specified by
           hyperparameters. The train or infer entry is picked from
           :attr:`infer_mode` when it is given, and from
           :attr:`self.training` otherwise. The call-time arguments
           (:attr:`embedding`, :attr:`start_tokens`, etc) are forwarded to
           the helper factory. Extra arguments for the helper constructor
           can be supplied through :attr:`**kwargs` or through
           ``hparams["helper_train/infer"]["kwargs"]``.

           This approach is used only when both :attr:`decoding_strategy`
           and :attr:`helper` are ``None``.

           Example:

             .. code-block:: python

                h = {
                    "helper_infer": {
                        "type": "GumbelSoftmaxEmbeddingHelper",
                        "kwargs": { "tau": 0.1 }
                    }
                }
                embedder = WordEmbedder(vocab_size=data.vocab.size)
                decoder = BasicRNNDecoder(vocab_size=data.vocab.size, hparams=h)

                # Gumbel-softmax decoding
                decoder.eval()  # disable dropout
                output, _, _ = decoder(
                    decoding_strategy=None, # Set to None explicitly
                    embedding=embedder,
                    start_tokens=[data.vocab.bos_token_id] * 100,
                    end_token=data.vocab.eos_token_id,
                    max_decoding_length=60)

        Args:
            decoding_strategy (str): Name of the decoding strategy. Which
                other arguments are required depends on the strategy.
                Ignored if :attr:`helper` is given.
            embedding (optional): A callable that returns embedding vectors
                of `inputs` (e.g., an instance of subclass of
                :class:`~texar.modules.EmbedderBase`), or the `params`
                argument of
                :tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>`.
                If provided, `inputs` (if used) will be passed to
                `embedding` to fetch the embedding vectors of the inputs.
                Required when :attr:`decoding_strategy` is ``"infer_greedy"``
                or ``"infer_sample"``; optional when
                ``decoding_strategy="train_greedy"``.
            start_tokens (optional): A :tensor:`LongTensor` of shape
                ``[batch_size]`` holding the start tokens.
                Used when :attr:`decoding_strategy` is ``"infer_greedy"`` or
                ``"infer_sample"``, or when an `hparams`-configured helper
                is used.
                When used with the Texar data module, to get ``batch_size``
                samples where ``batch_size`` changes with the data module,
                this can be set as
                ``start_tokens=torch.full_like(batch['length'], bos_token_id)``.
            end_token (optional): An integer or 0D :tensor:`LongTensor`, the
                token that marks the end of decoding.
                Used when :attr:`decoding_strategy` is ``"infer_greedy"`` or
                ``"infer_sample"``, or when an `hparams`-configured helper
                is used.
            softmax_temperature (float, optional): Value the logits are
                divided by before the softmax is computed. Larger values
                (above 1.0) result in more random samples. Must be > 0. If
                ``None``, 1.0 is used. Used when
                ``decoding_strategy="infer_sample"``.
            infer_mode (optional): If not ``None``, overrides the mode given
                by :attr:`self.training`.
            **kwargs: Other keyword arguments for constructing helpers
                defined by ``hparams["helper_train"]`` or
                ``hparams["helper_infer"]``.

        Returns:
            The constructed helper instance.

        Raises:
            ValueError: If :attr:`decoding_strategy` is not one of the three
                known strategy names, or if it is ``"infer_greedy"`` /
                ``"infer_sample"`` and any of :attr:`embedding`,
                :attr:`start_tokens`, :attr:`end_token` is ``None``.
        """
        if decoding_strategy is None:
            # No explicit strategy: fall back to the helper configured in
            # hparams. `infer_mode`, when given, wins over `self.training`.
            if infer_mode is None:
                use_train_helper = self.training
            else:
                use_train_helper = not infer_mode
            if use_train_helper:
                helper_hparams = self._hparams.helper_train
            else:
                helper_hparams = self._hparams.helper_infer

            helper_kwargs = copy.copy(helper_hparams.kwargs.todict())
            helper_type = helper_hparams.type
            # Call-time values are written over same-named hparams entries
            # (even when they are `None`); `**kwargs` is applied last and
            # therefore takes the highest precedence.
            helper_kwargs.update(
                time_major=self._input_time_major,
                embedding=embedding,
                start_tokens=start_tokens,
                end_token=end_token,
                softmax_temperature=softmax_temperature)
            helper_kwargs.update(kwargs)
            return decoder_helpers.get_helper(helper_type, **helper_kwargs)

        if decoding_strategy == 'train_greedy':
            return decoder_helpers.TrainingHelper(
                embedding, self._input_time_major)

        if decoding_strategy not in ('infer_greedy', 'infer_sample'):
            raise ValueError(
                f"Unknown decoding strategy: {decoding_strategy}")

        # Both inference strategies need all three of these to be set.
        # Identity checks are used so tensors are never compared by value.
        required_args = (embedding, start_tokens, end_token)
        if any(arg is None for arg in required_args):
            raise ValueError(
                f"When using '{decoding_strategy}' decoding strategy, "
                f"'embedding', 'start_tokens', and 'end_token' must "
                f"not be `None`.")

        if decoding_strategy == 'infer_greedy':
            return decoder_helpers.GreedyEmbeddingHelper(
                embedding, start_tokens, end_token)
        return decoder_helpers.SampleEmbeddingHelper(
            embedding, start_tokens, end_token, softmax_temperature)