Example 1
    def _test_sequence_loss(self, loss_fn, actions, logits, advantages, batched,
                            sequence_length):
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length)
        rank = get_rank(loss)
        self.assertEqual(rank, 0)

        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False)
        rank = get_rank(loss)
        self.assertEqual(rank, 1)
        self.assertEqual(loss.shape, torch.Size([self._max_time]))

        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False,
                       average_across_timesteps=True,
                       average_across_batch=False)
        rank = get_rank(loss)
        if batched:
            self.assertEqual(rank, 1)
            self.assertEqual(loss.shape, torch.Size([self._batch_size]))
        else:
            self.assertEqual(rank, 0)

        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length,
                       sum_over_timesteps=False,
                       average_across_batch=False)
        rank = get_rank(loss)
        if batched:
            self.assertEqual(rank, 2)
            self.assertEqual(loss.shape,
                             torch.Size([self._batch_size, self._max_time]))
        else:
            self.assertEqual(rank, 1)
            self.assertEqual(loss.shape,
                             torch.Size([self._max_time]))

        sequence_length_time = torch.randint(size=(self._max_time,),
                                             high=self._batch_size)
        loss = loss_fn(actions, logits, advantages, batched=batched,
                       sequence_length=sequence_length_time,
                       sum_over_timesteps=False,
                       average_across_batch=False,
                       time_major=True)
        if batched:
            self.assertEqual(loss.shape, torch.Size([self._batch_size,
                                                     self._max_time]))
        else:
            self.assertEqual(loss.shape, torch.Size([self._max_time]))
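
For context, a minimal sketch of a test method that could drive the helper above; the setUp attributes (self._actions_batch, etc.) and the tx.losses.pg_loss_with_logits entry point are assumptions about the surrounding test class, not shown in this snippet.

    def test_pg_loss_with_logits(self):
        """Hypothetical test; the tensor attributes are assumed to be created
        in setUp() with shapes [batch_size, max_time(, num_actions)].
        """
        # Batched inputs.
        self._test_sequence_loss(
            tx.losses.pg_loss_with_logits,
            self._actions_batch, self._logits_batch, self._advantages_batch,
            True, self._sequence_length)

        # A single (unbatched) sequence of shape [max_time(, num_actions)].
        self._test_sequence_loss(
            tx.losses.pg_loss_with_logits,
            self._actions_no_batch, self._logits_no_batch,
            self._advantages_no_batch, False, self._sequence_length)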
Example 2
    def test_sequence_sigmoid_cross_entropy(self):
        """Tests `texar.losses.sequence_sigmoid_cross_entropy`.
        """
        self._test_sequence_loss(tx.losses.sequence_sigmoid_cross_entropy,
                                 self._one_hot_labels, self._logits,
                                 self._sequence_length)

        self._test_sequence_loss(tx.losses.sequence_sigmoid_cross_entropy,
                                 self._one_hot_labels[:, :, 0],
                                 self._logits[:, :, 0], self._sequence_length)

        loss = tx.losses.sequence_sigmoid_cross_entropy(
            logits=self._logits[:, :, 0],
            labels=torch.ones([self._batch_size, self._max_time]),
            sequence_length=self._sequence_length)
        rank = get_rank(loss)
        self.assertEqual(rank, 0)
Example 3
    def _test_entropy(self, entropy_fn, logits, sequence_length=None):
        if sequence_length is None:
            entropy = entropy_fn(logits)
            rank = get_rank(entropy)
            self.assertEqual(rank, 0)

            entropy = entropy_fn(logits, average_across_batch=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 1)
            self.assertEqual(entropy.shape, torch.Size([self._batch_size]))
        else:
            entropy = entropy_fn(logits, sequence_length=sequence_length)
            rank = get_rank(entropy)
            self.assertEqual(rank, 0)

            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length,
                                 sum_over_timesteps=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 1)
            self.assertEqual(entropy.shape, torch.Size([self._max_time]))

            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length,
                                 sum_over_timesteps=False,
                                 average_across_timesteps=True,
                                 average_across_batch=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 1)
            self.assertEqual(entropy.shape, torch.Size([self._batch_size]))

            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length,
                                 sum_over_timesteps=False,
                                 average_across_batch=False)
            rank = get_rank(entropy)
            self.assertEqual(rank, 2)
            self.assertEqual(entropy.shape,
                             torch.Size([self._batch_size, self._max_time]))

            sequence_length_time = torch.randint(size=(self._max_time, ),
                                                 high=self._batch_size)
            entropy = entropy_fn(logits,
                                 sequence_length=sequence_length_time,
                                 sum_over_timesteps=False,
                                 average_across_batch=False,
                                 time_major=True)
            self.assertEqual(entropy.shape,
                             torch.Size([self._batch_size, self._max_time]))
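
For context, a minimal sketch of test methods that could exercise the helper above; the attribute names and the tx.losses entry points are assumptions about the surrounding test class.

    def test_entropy_with_logits(self):
        """Hypothetical test for the rank-2 (non-sequence) case."""
        self._test_entropy(tx.losses.entropy_with_logits,
                           self._logits[:, 0, :])

    def test_sequence_entropy_with_logits(self):
        """Hypothetical test for the sequence case."""
        self._test_entropy(tx.losses.sequence_entropy_with_logits,
                           self._logits,
                           sequence_length=self._sequence_length)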
    def _build(
            self,  # pylint: disable=arguments-differ, too-many-statements
            decoding_strategy='train_greedy',
            inputs=None,
            memory=None,
            memory_sequence_length=None,
            memory_attention_bias=None,
            beam_width=None,
            length_penalty=0.,
            start_tokens=None,
            end_token=None,
            context=None,
            context_sequence_length=None,
            softmax_temperature=None,
            max_decoding_length=None,
            impute_finished=False,
            embedding=None,
            helper=None,
            mode=None):
        """Performs decoding.

        The interface is mostly the same as that of RNN decoders
        (see :meth:`~texar.modules.RNNDecoderBase._build`). The main difference
        is that, here, `sequence_length` is not needed, and continuation
        generation is additionally supported.

        The function provides **3 ways** to specify the decoding method, with
        varying flexibility:

        1. The :attr:`decoding_strategy` argument.

            - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
              feeding ground truth to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Argument :attr:`inputs` is required for this strategy.
            - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
              `generated` sample to decode the next step), and for each step
              sample is obtained by taking the `argmax` of logits.
              Arguments :attr:`(start_tokens, end_token)` are
              required for this strategy, and argument
              :attr:`max_decoding_length` is optional.
            - **"infer_sample"**: decoding in inference fashion, and for each
              step sample is obtained by `random sampling` from the logits.
              Arguments :attr:`(start_tokens, end_token)` are required for this
              strategy, and argument :attr:`max_decoding_length` is optional.

          This argument is used only when arguments :attr:`helper` and
          :attr:`beam_width` are both `None`.

        2. The :attr:`helper` argument: An instance of subclass of
           :class:`texar.modules.Helper`.
           This provides a superset of the decoding strategies described above.
           The interface is the same as in RNN decoders.
           Please refer to :meth:`texar.modules.RNNDecoderBase._build` for
           detailed usage and examples.

           Note that, though using a
           :class:`~texar.modules.TrainingHelper` here corresponds to the
           "train_greedy" strategy above and produces the same outputs,
           the implementation is *slower* than
           directly setting `decoding_strategy="train_greedy"`.

           Argument :attr:`max_decoding_length` is optional.

        3. **Beam search**: set :attr:`beam_width` to use beam search decoding.
           Arguments :attr:`(start_tokens, end_token)` are required,
           and argument :attr:`max_decoding_length` is optional.

        Args:
            memory (optional): The memory to attend, e.g., the output of an RNN
                encoder. A Tensor of shape `[batch_size, memory_max_time, dim]`.
            memory_sequence_length (optional): A Tensor of shape `[batch_size]`
                containing the sequence lengths for the batch entries in
                memory. Used to create the attention bias if
                :attr:`memory_attention_bias` is not given. Ignored if
                :attr:`memory_attention_bias` is provided.
            memory_attention_bias (optional): A Tensor of shape
                `[batch_size, num_heads, memory_max_time, dim]`.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensor for teacher forcing decoding, of
                shape `[batch_size, target_max_time, emb_dim]` containing the
                target sequence word embeddings.
                Used when :attr:`decoding_strategy` is set to "train_greedy".
            decoding_strategy (str): A string specifying the decoding
                strategy, including "train_greedy", "infer_greedy",
                "infer_sample".
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` or :attr:`helper` is set.
            beam_width (int): Set to use beam search. If given,
                :attr:`decoding_strategy` is ignored.
            length_penalty (float): Length penalty coefficient used in beam
                search decoding. Refer to https://arxiv.org/abs/1609.08144
                for more details.
                It should be larger if longer sentences are wanted.
            start_tokens (optional): An int Tensor of shape `[batch_size]`,
                containing the start tokens.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
                Ignored if :attr:`context` is given.
            end_token (optional): An int 0D Tensor, the token that marks end
                of decoding.
                Used when :attr:`decoding_strategy` = "infer_greedy" or
                "infer_sample", or :attr:`beam_width` is set.
            context (optional): An int Tensor of shape `[batch_size, length]`,
                containing the starting tokens for decoding.
                If context is set, :attr:`start_tokens` will be ignored.
            context_sequence_length (optional): Specifies the length of each
                sequence in :attr:`context`.
            softmax_temperature (optional): A float 0D Tensor, value to divide
                the logits by before computing the softmax. Larger values
                (above 1.0) result in more random samples. Must be > 0. If
                `None`, 1.0 is used.
                Used when :attr:`decoding_strategy` = "infer_sample".
            max_decoding_length (optional): An int scalar Tensor indicating
                the maximum allowed number of decoding steps.
                If `None` (default), use "max_decoding_length" defined in
                :attr:`hparams`. Ignored in "train_greedy" decoding.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished. Ignored in
                "train_greedy" decoding.
            embedding (optional): Embedding used when :attr:`decoding_strategy`
                is "infer_greedy" or "infer_sample", or when beam
                search is used. This can be
                a callable or the `params` argument for
                :tf_main:`embedding_lookup <nn/embedding_lookup>`.
                If a callable, it can take a vector tensor of token `ids`,
                or take two arguments (`ids`, `times`), where `ids`
                is a vector tensor of token ids, and `times` is a vector tensor
                of time steps (i.e., position ids). The latter case can be used
                when :attr:`embedding` is a combination of word embedding and
                position embedding.
            helper (optional): An instance of
                :tf_main:`Helper <contrib/seq2seq/Helper>` that defines the
                decoding strategy. If given, :attr:`decoding_strategy` is
                ignored.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. Controls dropout mode.
                If `None` (default), :func:`texar.global_mode`
                is used.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of \
            :class:`~texar.modules.TransformerDecoderOutput` which contains\
            `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding or\
            decoding with :attr:`helper`, returns\
            a tuple `(outputs, sequence_lengths)`, where `outputs` is an \
            instance of :class:`~texar.modules.TransformerDecoderOutput` as\
            in "train_greedy", and `sequence_lengths` is a Tensor of shape\
            `[batch_size]` containing the length of each sample.

            - For **beam search** decoding, returns a `dict` containing keys\
            "sample_id" and "log_prob".

                - **"sample_id"** is an int Tensor of shape \
                `[batch_size, max_time, beam_width]` containing generated\
                token indexes. `sample_id[:,:,0]` is the most probable \
                sample.
                - **"log_prob"** is a float Tensor of shape \
                `[batch_size, beam_width]` containing the log probability \
                of each sequence sample.
        """

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - tf.sequence_mask(memory_sequence_length,
                                                   shape_list(memory)[1],
                                                   dtype=tf.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # context will be used in step function for dynamic_decode
        if context is not None:
            start_tokens = context[:, 0]
            self.context = context[:, 1:]
            self.context_sequence_length = context_sequence_length - 1
        else:
            self.context = None

        self.embedding = embedding

        if helper is None and beam_width is None and \
                decoding_strategy == 'train_greedy':  # Teacher-forcing

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                shape_list(inputs)[1]))

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias=decoder_self_attention_bias,
                memory_attention_bias=memory_attention_bias,
                cache=None,
                mode=mode)
            logits = self._output_layer(decoder_output)
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
            rets = TransformerDecoderOutput(logits=logits, sample_id=preds)

        else:
            if max_decoding_length is None:
                max_decoding_length = self._hparams.max_decoding_length
            self.max_decoding_length = max_decoding_length
            if beam_width is None:  # Inference-like decoding
                # Prepare helper
                if helper is None:
                    if decoding_strategy == "infer_greedy":
                        helper = tx_helper.GreedyEmbeddingHelper(
                            embedding, start_tokens, end_token)
                    elif decoding_strategy == "infer_sample":
                        helper = tx_helper.SampleEmbeddingHelper(
                            embedding, start_tokens, end_token,
                            softmax_temperature)
                    else:
                        raise ValueError(
                            "Unknown decoding strategy: {}".format(
                                decoding_strategy))
                self._helper = helper

                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=False)

                if context is not None:  # To avoid out-of-range in `step`
                    paddings = [[0, 0] for _ in range(get_rank(self.context))]
                    paddings[1][1] = \
                        max_decoding_length - shape_list(self.context)[1]
                    self.context = tf.pad(self.context, paddings=paddings)

                outputs, _, sequence_lengths = dynamic_decode(
                    decoder=self,
                    impute_finished=impute_finished,
                    maximum_iterations=max_decoding_length,
                    output_time_major=False,
                    scope=self.variable_scope)

                if context is not None:
                    # Here the length of sample_id will be larger than that
                    # of logits by 1, because there is an additional
                    # start token in the returned sample_id. This start token
                    # should be the first token of the given context.
                    outputs = TransformerDecoderOutput(
                        logits=outputs.logits,
                        sample_id=tf.concat([
                            tf.expand_dims(start_tokens, 1), outputs.sample_id
                        ],
                                            axis=1))
                    sequence_lengths = sequence_lengths + 1
                rets = outputs, sequence_lengths

            else:  # Beam-search decoding
                # Ignore `decoding_strategy`; Assume `helper` is not set
                if helper is not None:
                    raise ValueError("Must not set 'beam_width' and 'helper' "
                                     "simultaneously.")
                _batch_size = shape_list(start_tokens)[0]
                self._cache = self._init_cache(memory,
                                               memory_attention_bias,
                                               beam_search_decoding=True,
                                               batch_size=_batch_size)

                # The output format is different when running beam search
                sample_id, log_prob = self._beam_decode(
                    start_tokens,
                    end_token,
                    beam_width=beam_width,
                    length_penalty=length_penalty,
                    decode_length=max_decoding_length,
                )
                rets = {'sample_id': sample_id, 'log_prob': log_prob}

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return rets
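
For orientation, brief usage sketches of the three decoding ways documented above. The `decoder` instance, `embedder`, and all tensors (`target_embeds`, `enc_output`, `src_length`, `start_tokens`, `end_token`, `max_len`) are hypothetical placeholders.

# 1. Teacher-forcing ("train_greedy"); returns a TransformerDecoderOutput.
outputs = decoder(inputs=target_embeds,
                  memory=enc_output,
                  memory_sequence_length=src_length,
                  decoding_strategy='train_greedy')

# 2. Greedy inference; returns (outputs, sequence_lengths).
outputs, lengths = decoder(memory=enc_output,
                           memory_sequence_length=src_length,
                           decoding_strategy='infer_greedy',
                           embedding=embedder,
                           start_tokens=start_tokens,
                           end_token=end_token,
                           max_decoding_length=max_len)

# 3. Beam search; returns a dict with "sample_id" and "log_prob".
bs_outputs = decoder(memory=enc_output,
                     memory_sequence_length=src_length,
                     beam_width=5,
                     embedding=embedder,
                     start_tokens=start_tokens,
                     end_token=end_token,
                     max_decoding_length=max_len)
top_sample_ids = bs_outputs['sample_id'][:, :, 0]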
Example 5
    def _build(self,
               ids=None,
               soft_ids=None,
               stop_gradient=False,
               mode=None,
               **kwargs):
        """Embeds (soft) ids.

        Either :attr:`ids` or :attr:`soft_ids` must be given, and they
        must not be given at the same time.

        Args:
            ids (optional): An integer tensor containing the ids to embed.
            soft_ids (optional): A tensor of weights (probabilities) used to
                mix the embedding vectors.
            stop_gradient (bool): Whether to stop gradient back-propagation
                to the embedding tensor. This can be used when, e.g., updating
                only `soft_ids` while keeping the embedding tensor fixed.
                Defaults to `False`.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
                `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout is
                controlled by :func:`texar.global_mode`.
            kwargs: Additional keyword arguments for
                :tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>` besides
                :attr:`params` and :attr:`ids`.

        Returns:
            If :attr:`ids` is given, returns a Tensor of shape
            `shape(ids) + embedding-dim`. For example,
            if `shape(ids) = [batch_size, max_time]`
            and `shape(embedding) = [vocab_size, emb_dim]`, then the return
            tensor has shape `[batch_size, max_time, emb_dim]`.

            If :attr:`soft_ids` is given, returns a Tensor of shape
            `shape(soft_ids)[:-1] + embedding-dim`. For example,
            if `shape(soft_ids) = [batch_size, max_time, vocab_size]`
            and `shape(embedding) = [vocab_size, emb_dim]`, then the return
            tensor has shape `[batch_size, max_time, emb_dim]`.
        """
        if ids is not None:
            if soft_ids is not None:
                raise ValueError(
                    'Must not specify `ids` and `soft_ids` at the same time.')
            ids_rank = get_rank(ids)
        elif soft_ids is not None:
            ids_rank = get_rank(soft_ids) - 1
        else:
            raise ValueError('Either `ids` or `soft_ids` must be given.')

        embedding = self._embedding
        if stop_gradient:
            embedding = tf.stop_gradient(embedding)

        is_training = is_train_mode(mode)
        if self._hparams.dropout_strategy == 'item_type':
            dropout_layer = self._get_dropout_layer(self._hparams)
            if dropout_layer:
                embedding = dropout_layer.apply(inputs=embedding,
                                                training=is_training)

        if ids is not None:
            outputs = tf.nn.embedding_lookup(embedding, ids, **kwargs)
        else:
            outputs = embedder_utils.soft_embedding_lookup(embedding, soft_ids)

        if self._hparams.dropout_strategy != 'item_type':
            dropout_layer = self._get_dropout_layer(self._hparams,
                                                    ids_rank=ids_rank,
                                                    dropout_input=outputs)
            if dropout_layer:
                outputs = dropout_layer.apply(inputs=outputs,
                                              training=is_training)

        return outputs
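
A short usage sketch of the two call modes documented above; `embedder`, `text_ids`, and `decoder_logits` are hypothetical placeholders.

# Hard ids: one embedding vector per token id.
word_embeds = embedder(ids=text_ids)        # [batch_size, max_time, emb_dim]

# Soft ids: probability-weighted mixtures over the vocabulary, e.g. from a
# decoder's softmax output. `stop_gradient=True` keeps the embedding table
# fixed while gradients still flow to the soft ids.
soft_embeds = embedder(soft_ids=tf.nn.softmax(decoder_logits),
                       stop_gradient=True)  # [batch_size, max_time, emb_dim]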
Example 6
def pg_loss_with_log_probs(log_probs,
                           advantages,
                           rank=None,
                           batched=False,
                           sequence_length=None,
                           average_across_batch=True,
                           average_across_timesteps=False,
                           average_across_remaining=False,
                           sum_over_batch=False,
                           sum_over_timesteps=True,
                           sum_over_remaining=True,
                           time_major=False):
    """Policy gradient loss with log probs of actions.

    `pg_loss = reduce( advantages * -log_probs )`,
    where `advantages` does not back-propagate gradients.

    All arguments except :attr:`log_probs` are the same as
    :func:`pg_loss_with_logits`.

    Args:
        log_probs: Log probabilities of shape
            `[(batch_size,) max_time, ..., d_rank]` and dtype `float32`
            or `float64`. The rank of the Tensor is specified
            with :attr:`rank`.

            The batch dimension exists only if :attr:`batched` is `True`.

            The batch and time dimensions are exchanged, i.e.,
            `[max_time, batch_size, ...]` if :attr:`time_major` is `True`.
        advantages: Tensor of shape
            `[(batch_size,) max_time, d_3, ..., d_rank]` and
            dtype `float32` or `float64`.
            The batch dimension exists only if `batched` is `True`.
            The batch and time dimensions
            are exchanged if `time_major` is `True`.
        rank (int, optional): The rank of :attr:`log_probs`.
            If `None` (default), rank is automatically inferred from
            `log_probs` or `advantages`. If the inference fails,
            `rank` is set to 1 if `batched` is `False`,
            and set to 2 if `batched` is `True`.
        batched (bool): `True` if the inputs are batched.
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will have zero
            losses. Used if :attr:`batched` is `True`.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
            Ignored if `batched` is `False`.
        average_across_remaining (bool): If set, average the loss across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time. Ignored if there are
            no dimensions other than the batch and time dimensions.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
            Ignored if `batched` is `False`.
        sum_over_remaining (bool): If set, sum the loss across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time. Ignored if there are
            no dimensions other than the batch and time dimensions.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`log_probs` and :attr:`advantages` must have shape
            `[max_time, batch_size, ...]`. If `False` (default),
            they must have shape `[batch_size, max_time, ...]`.
            Ignored if :attr:`batched` is `False`.

    Returns:
        A Tensor containing the loss to minimize, whose rank depends on the
        reduce arguments. For example, the batch dimension is reduced if
        either :attr:`average_across_batch` or :attr:`sum_over_batch` is
        `True`, which decreases the rank of output tensor by 1.
    """
    advantages = tf.stop_gradient(advantages)

    losses = -log_probs * advantages

    if rank is None:
        rank = get_rank(log_probs) or get_rank(advantages)
    if rank is None:
        rank = 2 if batched else 1

    if batched:
        losses = mask_and_reduce(
            losses,
            sequence_length,
            rank=rank,
            average_across_batch=average_across_batch,
            average_across_timesteps=average_across_timesteps,
            average_across_remaining=average_across_remaining,
            sum_over_batch=sum_over_batch,
            sum_over_timesteps=sum_over_timesteps,
            sum_over_remaining=sum_over_remaining,
            time_major=time_major)
    elif rank > 1:
        if average_across_remaining and sum_over_remaining:
            raise ValueError("Only one of `average_across_remaining` and "
                             "`sum_over_remaining` can be set.")
        if average_across_remaining:
            losses = tf.reduce_mean(losses, axis=list(range(1, rank)))
        elif sum_over_remaining:
            losses = tf.reduce_sum(losses, axis=list(range(1, rank)))

    if not batched:
        if average_across_timesteps and sum_over_timesteps:
            raise ValueError("Only one of `average_across_timesteps` and "
                             "`sum_over_timesteps` can be set.")
        if average_across_timesteps:
            losses = tf.reduce_mean(losses, axis=0)
        elif sum_over_timesteps:
            losses = tf.reduce_sum(losses, axis=0)

    return losses
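
A hedged sketch of the batched use case; `dist`, `sampled_actions`, `rewards_to_go`, and `seq_lengths` are placeholders for whatever the surrounding policy code provides.

# Per-step log probabilities of the sampled actions: [batch_size, max_time].
log_probs = dist.log_prob(sampled_actions)

# With the default reductions (sum over time, average over the batch),
# the result is a scalar training loss.
pg_loss = pg_loss_with_log_probs(log_probs,
                                 advantages=rewards_to_go,
                                 batched=True,
                                 sequence_length=seq_lengths)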
Example 7
def sequence_entropy_with_logits(logits,
                                 rank=None,
                                 sequence_length=None,
                                 average_across_batch=True,
                                 average_across_timesteps=False,
                                 average_across_remaining=False,
                                 sum_over_batch=False,
                                 sum_over_timesteps=True,
                                 sum_over_remaining=True,
                                 time_major=False):
    """Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, max_time, d_3, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, .., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.

            The batch and time dimensions are exchanged if :attr:`time_major`
            is `True`.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), :attr:`rank` is inferred automatically from
            :attr:`logits`. If the inferred rank is `None`, :attr:`rank` is
            set to 3, i.e., assuming :attr:`logits` is of shape
            `[batch_size, max_time, distribution_dim]`.
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths are not
            counted into the entropy.
        average_across_timesteps (bool): If set, average the entropy across
            the time dimension. Must not set :attr:`average_across_timesteps`
            and :attr:`sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across the
            remaining dimensions. Must not set :attr:`average_across_remaining`
            and :attr:`sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        sum_over_timesteps (bool): If set, sum the entropy across the
            time dimension. Must not set :attr:`average_across_timesteps`
            and :attr:`sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimension. Must not set :attr:`average_across_remaining`
            and :attr:`sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 4.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`logits` must have shape `[max_time, batch_size, ...]`.
            If `False` (default), it must have shape
            `[batch_size, max_time, ...]`.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 3
    rank -= 1  # the last (distribution) dimension is summed out

    entropy = mask_and_reduce(
        entropy,
        sequence_length,
        rank=rank,
        average_across_batch=average_across_batch,
        average_across_timesteps=average_across_timesteps,
        average_across_remaining=average_across_remaining,
        sum_over_batch=sum_over_batch,
        sum_over_timesteps=sum_over_timesteps,
        sum_over_remaining=sum_over_remaining,
        time_major=time_major)

    return entropy
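
A hedged usage sketch; `decoder_logits` and `seq_lengths` are placeholders, with `decoder_logits` assumed to have shape `[batch_size, max_time, vocab_size]`.

# Per-example entropy: time steps are summed (default) and the batch
# dimension is kept, giving a Tensor of shape [batch_size].
ent = sequence_entropy_with_logits(decoder_logits,
                                   sequence_length=seq_lengths,
                                   average_across_batch=False)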
Example 8
def entropy_with_logits(logits,
                        rank=None,
                        average_across_batch=True,
                        average_across_remaining=False,
                        sum_over_batch=False,
                        sum_over_remaining=True):
    """Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, .., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), :attr:`rank` is inferred automatically from
            :attr:`logits`. If the inferred rank is `None`, :attr:`rank` is
            set to 2, i.e., assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across the
            remaining dimensions. Must not set :attr:`average_across_remaining`
            and :attr:`sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set :attr:`average_across_batch`
            and :attr:`sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimension. Must not set :attr:`average_across_remaining`
            and :attr:`sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    rank -= 1  # the last (distribution) dimension is summed out

    # Reduces
    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")
    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining and rank >= 2:
        sum_axes += list(range(1, rank))
    if average_across_remaining and rank >= 2:
        average_axes += list(range(1, rank))

    entropy = reduce_dimensions(
        entropy, average_axes=average_axes, sum_axes=sum_axes)

    return entropy
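
A hedged sketch of the rank >= 3 case that the `*_remaining` flags refer to; `logits` here is a placeholder assumed to have shape `[batch_size, num_samples, vocab_size]`.

# Default reductions: entropy over the vocab dimension, summed over the
# remaining `num_samples` dimension and averaged over the batch -> scalar.
ent = entropy_with_logits(logits, rank=3)

# Keep both the batch and sample dimensions: shape [batch_size, num_samples].
ent_per_sample = entropy_with_logits(logits, rank=3,
                                     average_across_batch=False,
                                     sum_over_remaining=False)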
def sequence_sigmoid_cross_entropy(labels,
                                   logits,
                                   sequence_length,
                                   average_across_batch=True,
                                   average_across_timesteps=False,
                                   average_across_classes=True,
                                   sum_over_batch=False,
                                   sum_over_timesteps=True,
                                   sum_over_classes=False,
                                   time_major=False,
                                   stop_gradient_to_label=False,
                                   name=None):
    """Computes sigmoid cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class distributions.

            - If :attr:`time_major` is `False` (default), this must be a\
            Tensor of shape `[batch_size, max_time(, num_classes)]`.

            - If `time_major` is `True`, this must be a Tensor of shape\
            `[max_time, batch_size(, num_classes)]`.

            Each row of `labels` should be a valid probability
            distribution; otherwise, the computation of the gradient will be
            incorrect.
        logits: Unscaled log probabilities having the same shape as
            :attr:`labels`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if it exists). Must not set
            `average_across_classes` and `sum_over_classes` at
            the same time. Ignored if :attr:`logits` is a 2D Tensor.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time. Ignored if
            :attr:`logits` is a 2D Tensor.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        stop_gradient_to_label (bool): If set, gradient propagation to
            :attr:`labels` will be disabled.
        name (str, optional): A name for the operation.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments
        :attr:`{average_across}/{sum_over}_{timesteps}/{batch}/{classes}`.
        For example, if the class dimension does not exist, then:

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`  \
        are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are \
        `False`, the return Tensor is of shape `[max_time]`.
    """

    with tf.name_scope(name, "sequence_sigmoid_cross_entropy"):
        if stop_gradient_to_label:
            labels = tf.stop_gradient(labels)

        losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                         logits=logits)

        rank = shapes.get_rank(logits) or shapes.get_rank(labels)
        if rank is None:
            raise ValueError(
                'Cannot determine the rank of `logits` or `labels`.')

        losses = mask_and_reduce(
            losses,
            sequence_length,
            rank=rank,
            average_across_batch=average_across_batch,
            average_across_timesteps=average_across_timesteps,
            average_across_remaining=average_across_classes,
            sum_over_batch=sum_over_batch,
            sum_over_timesteps=sum_over_timesteps,
            sum_over_remaining=sum_over_classes,
            time_major=time_major)

        return losses
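
A hedged sketch of a non-default reduction; `binary_labels`, `binary_logits`, and `lengths` are placeholders of shape `[batch_size, max_time]` (no class dimension), matching the second bullet in the Returns section above.

# Loss per time step, averaged over the batch: a Tensor of shape [max_time].
stepwise_loss = sequence_sigmoid_cross_entropy(
    labels=binary_labels,
    logits=binary_logits,
    sequence_length=lengths,
    sum_over_timesteps=False,
    average_across_timesteps=False)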
Example 10
def entropy_with_logits(logits: torch.Tensor,
                        rank: Optional[int] = None,
                        average_across_batch: bool = True,
                        average_across_remaining: bool = False,
                        sum_over_batch: bool = False,
                        sum_over_remaining: bool = True) -> torch.Tensor:
    r"""Shannon entropy given logits.

    Args:
        logits: Unscaled log probabilities of shape
            `[batch_size, d_2, ..., d_{rank-1}, distribution_dim]`
            and of dtype `float32` or `float64`.

            The rank of the tensor is optionally specified by the argument
            :attr:`rank`.

            The tensor is considered as having `[batch_size, .., d_{rank-1}]`
            elements, each of which has a distribution of length `d_rank`
            (i.e., `distribution_dim`). So the last dimension is always
            summed out to compute the entropy.
        rank (int, optional): The rank of :attr:`logits`.
            If `None` (default), `rank` is inferred automatically from
            `logits`. If the inference fails, `rank` is
            set to 2, i.e., assuming :attr:`logits` is of shape
            `[batch_size, distribution_dim]`.
        average_across_batch (bool): If set, average the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_remaining (bool): If set, average the entropy across the
            remaining dimensions. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.
        sum_over_batch (bool): If set, sum the entropy across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_remaining (bool): If set, sum the entropy across the
            remaining dimension. Must not set `average_across_remaining`
            and `sum_over_remaining` at the same time.
            Used only when :attr:`logits` has rank >= 3.

    Returns:
        A Tensor containing the Shannon entropy. The dimensionality of the
        Tensor depends on the configuration of reduction arguments. For
        example, if both batch and remaining dimensions are reduced (by
        either sum or average), the returned Tensor is a scalar Tensor.
    """
    entropy = _get_entropy(logits)

    if rank is None:
        rank = get_rank(logits)
    if rank is None:
        rank = 2
    rank -= 1

    if average_across_batch and sum_over_batch:
        raise ValueError("Only one of `average_across_batch` and "
                         "`sum_over_batch` can be set.")
    if average_across_remaining and sum_over_remaining:
        raise ValueError("Only one of `average_across_remaining` and "
                         "`sum_over_remaining` can be set.")
    sum_axes, average_axes = [], []
    if sum_over_batch:
        sum_axes.append(0)
    if average_across_batch:
        average_axes.append(0)
    if sum_over_remaining and rank >= 2:
        sum_axes += list(range(1, rank))
    if average_across_remaining and rank >= 2:
        average_axes += list(range(1, rank))

    entropy = reduce_dimensions(entropy,
                                average_axes=average_axes,
                                sum_axes=sum_axes)

    return entropy
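
A small self-contained sketch of the default rank-2 case; the shapes and values are illustrative only.

import torch

# 4 examples, each with an unnormalized distribution over 10 classes.
example_logits = torch.randn(4, 10)

# Default reduction: per-example entropy averaged over the batch -> scalar.
avg_entropy = entropy_with_logits(example_logits)

# Keep the batch dimension instead: a Tensor of shape [4].
per_example_entropy = entropy_with_logits(example_logits,
                                          average_across_batch=False)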
Example 11
def sequence_sigmoid_cross_entropy(
        labels: torch.Tensor,
        logits: torch.Tensor,
        sequence_length: Optional[torch.LongTensor],
        average_across_batch: bool = True,
        average_across_timesteps: bool = False,
        average_across_classes: bool = True,
        sum_over_batch: bool = False,
        sum_over_timesteps: bool = True,
        sum_over_classes: bool = False,
        time_major: bool = False,
        stop_gradient_to_label: bool = False) -> torch.Tensor:
    r"""Computes sigmoid cross entropy for each time step of sequence
    predictions.

    Args:
        labels: Target class distributions.

            - If :attr:`time_major` is `False` (default), this must be a
              Tensor of shape `[batch_size, max_time(, num_classes)]`.

            - If `time_major` is `True`, this must be a Tensor of shape
              `[max_time, batch_size(, num_classes)]`.

            Each row of `labels` should be a valid probability
            distribution; otherwise, the computation of the gradient will be
            incorrect.
        logits: Unscaled log probabilities having the same shape as
            :attr:`labels`.
        sequence_length: A Tensor of shape `[batch_size]`. Time steps beyond
            the respective sequence lengths will have zero losses.
        average_across_timesteps (bool): If set, average the loss across
            the time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        average_across_batch (bool): If set, average the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        average_across_classes (bool): If set, average the loss across the
            class dimension (if it exists). Must not set
            `average_across_classes` and `sum_over_classes` at
            the same time. Ignored if :attr:`logits` is a 2D Tensor.
        sum_over_timesteps (bool): If set, sum the loss across the
            time dimension. Must not set `average_across_timesteps`
            and `sum_over_timesteps` at the same time.
        sum_over_batch (bool): If set, sum the loss across the
            batch dimension. Must not set `average_across_batch`
            and `sum_over_batch` at the same time.
        sum_over_classes (bool): If set, sum the loss across the
            class dimension. Must not set `average_across_classes`
            and `sum_over_classes` at the same time. Ignored if
            :attr:`logits` is a 2D Tensor.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`labels` and :attr:`logits` must have shape
            `[max_time, batch_size, ...]`. If `False`
            (default), they must have shape `[batch_size, max_time, ...]`.
        stop_gradient_to_label (bool): If set, gradient propagation to
            :attr:`labels` will be disabled.

    Returns:
        A Tensor containing the loss, of rank 0, 1, or 2 depending on the
        arguments
        :attr:`{average_across}/{sum_over}_{timesteps}/{batch}/{classes}`.
        For example, if the class dimension does not exist, then:

        - If :attr:`sum_over_timesteps` and :attr:`average_across_batch`
          are `True` (default), the return Tensor is of rank 0.

        - If :attr:`average_across_batch` is `True` and other arguments are
          `False`, the return Tensor is of shape `[max_time]`.
    """
    if stop_gradient_to_label:
        labels = labels.detach()
    losses = F.binary_cross_entropy_with_logits(logits,
                                                labels.type(logits.dtype),
                                                reduction='none')

    rank = shapes.get_rank(logits) or shapes.get_rank(labels)

    losses = mask_and_reduce(losses,
                             sequence_length,
                             rank=rank,
                             average_across_batch=average_across_batch,
                             average_across_timesteps=average_across_timesteps,
                             average_across_remaining=average_across_classes,
                             sum_over_batch=sum_over_batch,
                             sum_over_timesteps=sum_over_timesteps,
                             sum_over_remaining=sum_over_classes,
                             time_major=time_major)

    return losses
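
A small self-contained sketch with binary labels and no class dimension; the shapes and values are illustrative only.

import torch

batch_size, max_time = 4, 7
example_logits = torch.randn(batch_size, max_time)
example_labels = torch.randint(high=2, size=(batch_size, max_time)).float()
lengths = torch.tensor([7, 5, 3, 6])

# Default reductions (sum over time, average over the batch) -> scalar loss.
loss = sequence_sigmoid_cross_entropy(labels=example_labels,
                                      logits=example_logits,
                                      sequence_length=lengths)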