Example #1
    def forward(
            self,  # type: ignore
            input: torch.Tensor,
            sequence_length: Optional[Union[torch.LongTensor, List[int]]] = None,
            dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        r"""Feeds forward inputs through the network layers and returns outputs.

        Args:
            input: The inputs to the network, which is a 3D tensor.
            sequence_length (optional): An :tensor:`LongTensor` of shape
                ``[batch_size]`` or a Python list containing the length of
                each element in :attr:`input`. If given, time steps beyond
                the length will first be masked out before feeding to the
                layers.
            dtype (optional): Type of the inputs. If not provided,
                infers from inputs automatically.
        Returns:
            The output of the final layer.
        """
        if sequence_length is not None:
            input = mask_sequences(input,
                                   sequence_length,
                                   dtype=dtype,
                                   time_major=False)
        return super().forward(input)
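
A minimal sketch of the masking step used above (shapes and values are
illustrative; assumes texar.torch is installed):

import torch
from texar.torch.utils.shapes import mask_sequences

inputs = torch.ones(2, 4, 3)      # [batch_size, max_time, dim]
lengths = torch.tensor([4, 2])    # valid length of each batch element
masked = mask_sequences(inputs, lengths, time_major=False)
# masked[0] is unchanged; masked[1, 2:] is all zeros.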
Example #2
    def test_mask_sequences(self):
        r"""Tests :func:`texar.torch.utils.shapes.mask_sequences`.
        """
        seq = torch.ones(3, 4, 3, dtype=torch.int32)
        seq_length = torch.tensor([3, 2, 1], dtype=torch.int32)

        masked_seq = shapes.mask_sequences(seq, seq_length)
        np.testing.assert_array_equal(masked_seq.shape, seq.shape)
        seq_sum = torch.sum(masked_seq, dim=(1, 2))
        np.testing.assert_array_equal(seq_sum, seq_length * 3)
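
The same check can be written in plain PyTorch (a sketch using the tensors
from the test above; note that torch.sum upcasts the int32 input to int64):

assert masked_seq.shape == seq.shape
assert bool((torch.sum(masked_seq, dim=(1, 2)) == seq_length * 3).all())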
Example #3
    def forward(self,  # type: ignore
                positions: Optional[torch.LongTensor] = None,
                sequence_length: Optional[torch.LongTensor] = None, **kwargs) \
            -> torch.Tensor:
        r"""Embeds.
        Either :attr:`positions` or :attr:`sequence_length` is required:

        - If both are given, :attr:`sequence_length` is used to mask out
          embeddings of those time steps beyond the respective sequence
          lengths.
        - If only :attr:`sequence_length` is given, then positions
          from `0` to `sequence_length - 1` are embedded.

        Args:
            positions (optional): An :tensor:`LongTensor` containing the
                position IDs to embed.
            sequence_length (optional): An :tensor:`LongTensor` of shape
                ``[batch_size]``. Time steps beyond
                the respective sequence lengths will have zero-valued
                embeddings.

        Returns:
            A Tensor of shape ``[batch_size, position_size, dim]``.
        """
        if positions is None:
            if sequence_length is None:
                raise ValueError(
                    'Either `positions` or `sequence_length` is required.')
            max_length = sequence_length.max()
            batch_size = sequence_length.size(0)
            inputs = torch.arange(max_length).to(device=sequence_length.device)
            inputs = inputs.expand(batch_size, max_length)
        else:
            inputs = positions

        if self._cache_embeddings:
            outputs = F.embedding(inputs, self.signal, **kwargs)
        else:
            outputs = self._compute_embeddings(inputs, self.inv_timescales)

        if sequence_length is not None:
            outputs = mask_sequences(outputs, sequence_length)

        return outputs
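
A sketch of the position construction above: when only sequence_length is
given, positions 0 .. max_length - 1 are built for every batch element and
the resulting embeddings are masked afterwards.

import torch

sequence_length = torch.tensor([4, 2])
max_length = int(sequence_length.max())
batch_size = sequence_length.size(0)
positions = torch.arange(max_length).expand(batch_size, max_length)
# positions == [[0, 1, 2, 3], [0, 1, 2, 3]]; after mask_sequences, the
# embeddings of row 1 at steps 2 and 3 become zero vectors.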
Example #4
def _forward_output_layers(
        inputs: torch.Tensor,
        output_layer: Optional[nn.Module],
        time_major: bool,
        sequence_length: Optional[Union[torch.LongTensor, List[int]]] = None) \
        -> Tuple[torch.Tensor, int]:
    r"""Forwards inputs through the output layers.

    Args:
        inputs: A Tensor of shape ``[batch_size, max_time] + input_size`` if
            :attr:`time_major` is `False`, or shape
            ``[max_time, batch_size] + input_size`` if :attr:`time_major` is
            `True`.
        output_layer (optional): :torch_nn:`Sequential` or :torch_nn:`Module`
            of output layers.
        time_major (bool): The shape format of the :attr:`inputs` and
            :attr:`outputs` Tensors. If `True`, these tensors are of shape
            `[max_time, batch_size, input_size]`. If `False` (default),
            these tensors are of shape `[batch_size, max_time, input_size]`.
        sequence_length (optional): A 1D :tensor:`LongTensor` of shape
            ``[batch_size]``. Sequence lengths of the batch inputs. Used to
            copy-through state and zero-out outputs when past a batch element's
            sequence length.

    Returns:
        A pair :attr:`(outputs, outputs_size)`, where

        - :attr:`outputs`: A Tensor of shape
          ``[batch_size, max_time] + outputs_size``.

        - :attr:`outputs_size`: An `int` representing the output size.
    """
    if output_layer is None:
        return inputs, inputs.shape[-1]

    output = output_layer(inputs)

    if sequence_length is not None:
        output = mask_sequences(output, sequence_length, time_major=time_major)

    output_size = output.shape[-1]

    return output, output_size
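
An illustrative call (a sketch; _forward_output_layers is a private helper,
shown here with a plain Linear output layer):

import torch
from torch import nn

layer = nn.Linear(8, 5)
inputs = torch.randn(2, 7, 8)     # [batch_size, max_time, input_size]
lengths = torch.tensor([7, 4])
output, size = _forward_output_layers(
    inputs, layer, time_major=False, sequence_length=lengths)
# output: [2, 7, 5] with steps past each length zeroed out; size == 5.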
Example #5
def _discount_reward_tensor_1d(reward: torch.Tensor,
                               sequence_length: Optional[torch.LongTensor],
                               discount: float = 1.) -> torch.Tensor:
    r"""Computes discounted reward.

    Args:
        reward: 1D Tensor with shape `[batch_size]`.
        sequence_length: A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will be masked.
        discount (float): A scalar. The discount factor.

    Returns:
        A 2D Tensor of the discounted reward.
    """
    if sequence_length is None:
        raise ValueError('sequence_length must not be `None` for 1D reward.')

    if not isinstance(sequence_length, torch.Tensor):
        sequence_length = torch.tensor(sequence_length,
                                       dtype=torch.int64,
                                       device=reward.device)

    batch_size = reward.shape[0]
    max_seq_length = torch.max(sequence_length)
    dtype: torch.dtype = reward.dtype

    if discount == 1.:
        disc_reward = reward.unsqueeze(-1).expand(batch_size, max_seq_length)
    else:
        mask = sequence_mask(sequence_length, dtype=dtype)
        mask = torch.cat((mask[:, 1:], torch.zeros_like(mask[:, -1:])), dim=1)
        # Make each row = [discount, ..., discount, 1, ..., 1]
        dmat = mask * discount + (1 - mask)
        dmat = torch.flip(dmat, (1, ))
        dmat = torch.cumprod(dmat, dim=1)
        dmat = torch.flip(dmat, (1, ))
        disc_reward = dmat * reward.unsqueeze(-1)

    disc_reward = mask_sequences(disc_reward, sequence_length, dtype=dtype)

    return disc_reward
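
A quick numeric check of the discount-matrix trick (a sketch; assumes the
helper and its texar.torch imports are in scope):

import torch

reward = torch.tensor([2.0])
lengths = torch.tensor([3])
out = _discount_reward_tensor_1d(reward, lengths, discount=0.5)
# out == [[0.5, 1.0, 2.0]]: step t receives discount**(T - 1 - t) * reward,
# so the final valid step gets the undiscounted reward.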
Example #6
def _discount_reward_tensor_2d(reward: torch.Tensor,
                               sequence_length: Optional[
                                   torch.LongTensor] = None,
                               discount: float = 1.) -> torch.Tensor:
    r"""Computes discounted reward.

    Args:
        reward: 2D Tensor with shape `[batch_size, max_time]`.
        sequence_length (optional): A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will be masked.
        discount (float): A scalar. The discount factor.

    Returns:
        A 2D Tensor of the discounted reward.
    """
    dtype: torch.dtype = reward.dtype
    if sequence_length is not None:
        reward = mask_sequences(reward, sequence_length, dtype=dtype)

    if discount == 1.:
        reward = torch.flip(reward, (1, ))
        disc_reward = torch.cumsum(reward, dim=1)
        disc_reward = torch.flip(disc_reward, (1, ))
    else:
        # [max_time, batch_size]
        rev_reward_T = torch.flip(reward, (1, )).permute(1, 0)

        res = []
        acc = torch.zeros_like(reward[:, 1])
        for i in range(rev_reward_T.shape[0]):
            cur = rev_reward_T[i]
            acc = cur + discount * acc
            res.append(acc)

        rev_reward_T_cum = torch.stack(res, dim=0)
        disc_reward = torch.flip(rev_reward_T_cum.permute(1, 0), (1, ))

    return disc_reward
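
A matching numeric sketch for the 2D variant: each step holds the discounted
sum of its own suffix of per-step rewards.

import torch

reward = torch.tensor([[1.0, 1.0, 1.0]])
out = _discount_reward_tensor_2d(reward, discount=0.5)
# out == [[1.75, 1.5, 1.0]], i.e. (1 + 0.5 + 0.25, 1 + 0.5, 1).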
Example #7
    def forward(self,  # type: ignore
                input: torch.Tensor,
                sequence_length: Optional[Union[torch.LongTensor,
                                                List[int]]] = None,
                dtype: Optional[torch.dtype] = None,
                data_format: Optional[str] = None) -> torch.Tensor:
        r"""Feeds forward inputs through the network layers and returns outputs.

        Args:
            input: The inputs to the network, which is a 3D tensor.
            sequence_length (optional): An :tensor:`LongTensor` of shape
                ``[batch_size]`` or a Python list containing the length of
                each element in :attr:`input`. If given, time steps beyond
                the length will first be masked out before feeding to the
                layers.
            dtype (optional): Type of the inputs. If not provided,
                infers from inputs automatically.
            data_format (optional): Data format of the input tensor. If
                ``channels_last``, the last dimension is treated as the
                channel dimension, so the size of the :attr:`input` should be
                `[batch_size, X, channel]`. If ``channels_first``, the first
                dimension is treated as the channel dimension and the size
                should be `[batch_size, channel, X]`. Defaults to `None`,
                in which case the value is picked from hyperparameters.
        Returns:
            The output of the final layer.
        """
        if input.dim() != 3:
            raise ValueError("'input' should be a 3D tensor.")

        if data_format is None:
            data_format = self.hparams["data_format"]

        if data_format == "channels_first":
            # masking requires channels in last dimension
            input = input.permute(0, 2, 1)

            if sequence_length is not None:
                input = mask_sequences(input, sequence_length,
                                       dtype=dtype, time_major=False)

            # network is constructed for channel first tensors
            input = input.permute(0, 2, 1)

            output = super().forward(input)

        elif data_format == "channels_last":
            if sequence_length is not None:
                input = mask_sequences(input, sequence_length,
                                       dtype=dtype, time_major=False)

            input = input.permute(0, 2, 1)

            output = super().forward(input)

            # transpose only when tensors are 3D
            if output.dim() == 3:
                output = output.permute(0, 2, 1)
        else:
            raise ValueError("Invalid 'data_format'")

        return output
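
A sketch of the "channels_first" masking detour above: mask_sequences expects
channels last, so the input is permuted, masked, and permuted back before the
convolution stack.

import torch
from texar.torch.utils.shapes import mask_sequences

x = torch.randn(2, 8, 5)                  # [batch_size, channels, time]
lengths = torch.tensor([5, 3])
x = x.permute(0, 2, 1)                    # -> [batch_size, time, channels]
x = mask_sequences(x, lengths, time_major=False)
x = x.permute(0, 2, 1)                    # back to [batch_size, channels, time]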
Example #8
    def forward(self,  # type: ignore
                inputs: Optional[torch.Tensor] = None,
                sequence_length: Optional[torch.LongTensor] = None,
                memory: Optional[torch.Tensor] = None,
                memory_sequence_length: Optional[torch.LongTensor] = None,
                memory_attention_bias: Optional[torch.Tensor] = None,
                context: Optional[torch.Tensor] = None,
                context_sequence_length: Optional[torch.LongTensor] = None,
                helper: Optional[Helper] = None,
                decoding_strategy: str = 'train_greedy',
                max_decoding_length: Optional[int] = None,
                impute_finished: bool = False,
                infer_mode: Optional[bool] = None,
                beam_width: Optional[int] = None,
                length_penalty: float = 0.,
                **kwargs) \
            -> Union[
                TransformerDecoderOutput,
                Tuple[TransformerDecoderOutput, torch.LongTensor],
                Dict[str, torch.Tensor]]:
        r"""Performs decoding.

        The interface is very similar to that of RNN decoders
        (:class:`~texar.torch.modules.RNNDecoderBase`). In particular,
        the function provides **3 ways** to specify the decoding method, with
        varying flexibility:

        1. The :attr:`decoding_strategy` argument.

           - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
             feeding ground truth to decode the next step), and for each step
             sample is obtained by taking the `argmax` of logits.
             Argument :attr:`inputs` is required for this strategy.
             :attr:`sequence_length` is optional.
           - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
             `generated` sample to decode the next step), and for each step
             sample is obtained by taking the `argmax` of logits.
             Arguments :attr:`(start_tokens, end_token)` are
             required for this strategy, and argument
             :attr:`max_decoding_length` is optional.
           - **"infer_sample"**: decoding in inference fashion, and for each
             step sample is obtained by `random sampling` from the logits.
             Arguments :attr:`(start_tokens, end_token)` are required for this
             strategy, and argument :attr:`max_decoding_length` is optional.

           This argument is used only when arguments :attr:`helper` and
           :attr:`beam_width` are both `None`.

        2. The :attr:`helper` argument: An instance of subclass of
           :class:`~texar.torch.modules.Helper`.
           This provides a superset of the decoding strategies above.
           The interface is the same as in RNN decoders.
           Please refer to :meth:`texar.torch.modules.RNNDecoderBase.forward`
           for detailed usage and examples.

           Note that using a
           :class:`~texar.torch.modules.TrainingHelper` corresponding to the
           ``"train_greedy"`` strategy above is *slower* than directly
           setting ``decoding_strategy="train_greedy"``, though the output
           results are the same.

           Argument :attr:`max_decoding_length` is optional.

        3. **Beam search**: set :attr:`beam_width` to use beam search decoding.
           Arguments :attr:`(start_tokens, end_token)` are required,
           and argument :attr:`max_decoding_length` is optional.

        Args:
            memory (optional): The memory to attend, e.g., the output of an RNN
                encoder. A :tensor:`Tensor` of shape
                ``[batch_size, memory_max_time, dim]``.
            memory_sequence_length (optional): A :tensor:`Tensor` of shape
                ``[batch_size]`` containing the sequence lengths for the batch
                entries in memory. Used to create the attention bias if
                :attr:`memory_attention_bias` is not given. Ignored if
                :attr:`memory_attention_bias` is provided.
            memory_attention_bias (optional): A :tensor:`Tensor` of shape
                ``[batch_size, num_heads, memory_max_time, dim]``.
                An attention bias typically sets the value of a padding
                position to a large negative value for masking. If not given,
                :attr:`memory_sequence_length` is used to automatically
                create an attention bias.
            inputs (optional): Input tensors for teacher forcing decoding.
                Used when :attr:`decoding_strategy` is set to
                ``"train_greedy"``, or when `hparams`-configured helper is used.

                :attr:`inputs` is a :tensor:`LongTensor` used as indices to
                look up embeddings to feed into the decoder. For example, if
                :attr:`embedder` is an instance of
                :class:`~texar.torch.modules.WordEmbedder`, then :attr:`inputs`
                is usually a 2D int Tensor `[batch_size, max_time]` (or
                `[max_time, batch_size]` if `input_time_major` == `True`)
                containing the token indexes.
            sequence_length (optional): A :tensor:`LongTensor` of shape
                ``[batch_size]``, containing the sequence length of
                :attr:`inputs`. Tokens beyond the respective sequence length are
                masked out.
                Used when :attr:`decoding_strategy` is set to
                ``"train_greedy"``.
            decoding_strategy (str): A string specifying the decoding
                strategy, including ``"train_greedy"``, ``"infer_greedy"``,
                ``"infer_sample"``.
                Different arguments are required based on the
                strategy. See above for details. Ignored if
                :attr:`beam_width` or :attr:`helper` is set.
            beam_width (int): Set to use beam search. If given,
                :attr:`decoding_strategy` is ignored.
            length_penalty (float): Length penalty coefficient used in beam
                search decoding. Refer to https://arxiv.org/abs/1609.08144
                for more details.
                It should be larger if longer sentences are desired.
            context (optional): An :tensor:`LongTensor` of shape
                ``[batch_size, length]``, containing the starting tokens for
                decoding. If context is set, ``start_tokens`` of the
                :class:`~texar.torch.modules.Helper` will be ignored.
            context_sequence_length (optional): Specify the length of context.
            max_decoding_length (int, optional): The maximum allowed number of
                decoding steps.
                If `None` (default), use ``"max_decoding_length"`` defined in
                :attr:`hparams`. Ignored in ``"train_greedy"`` decoding.
            impute_finished (bool): If `True`, then states for batch
                entries which are marked as finished get copied through and
                the corresponding outputs get zeroed out.  This causes some
                slowdown at each time step, but ensures that the final state
                and outputs have the correct values and that backprop ignores
                time steps that were marked as finished. Ignored in
                ``"train_greedy"`` decoding.
            helper (optional): An instance of
                :class:`~texar.torch.modules.Helper`
                that defines the decoding strategy. If given,
                ``decoding_strategy`` and helper configurations in
                :attr:`hparams` are ignored.
            infer_mode (optional): If not `None`, overrides mode given by
                :attr:`self.training`.

        Returns:

            - For **"train_greedy"** decoding, returns an instance of
              :class:`~texar.torch.modules.TransformerDecoderOutput` which
              contains `sample_id` and `logits`.

            - For **"infer_greedy"** and **"infer_sample"** decoding or
              decoding with :attr:`helper`, returns
              a tuple ``(outputs, sequence_lengths)``, where ``outputs`` is an
              instance of :class:`~texar.torch.modules.TransformerDecoderOutput`
              as in `"train_greedy"`, and ``sequence_lengths`` is a
              :tensor:`LongTensor` of shape ``[batch_size]`` containing the
              length of each sample.

            - For **beam search** decoding, returns a ``dict`` containing keys
              ``"sample_id"`` and ``"log_prob"``.

                - ``"sample_id"`` is a :tensor:`LongTensor` of shape
                  ``[batch_size, max_time, beam_width]`` containing generated
                  token indexes. ``sample_id[:, :, 0]`` is the most probable
                  sample.
                - ``"log_prob"`` is a :tensor:`Tensor` of shape
                  ``[batch_size, beam_width]`` containing the log probability
                  of each sequence sample.
        """

        if memory is not None:
            if memory_attention_bias is None:
                if memory_sequence_length is None:
                    raise ValueError("`memory_sequence_length` is required if "
                                     "`memory_attention_bias` is not given.")

                enc_padding = 1 - sequence_mask(memory_sequence_length,
                                                memory.size(1),
                                                dtype=torch.float32)
                memory_attention_bias = attn.attention_bias_ignore_padding(
                    enc_padding)

        # record the context, which will be used in step function
        # for dynamic_decode
        if context is not None:
            if context_sequence_length is None:
                raise ValueError("'context_sequence_length' must not be None"
                                 "when 'context' is specified.")
            self._state_context = context[:, 1:]
            self._state_context_sequence_length = context_sequence_length - 1
        else:
            self._state_context = None
            self._state_context_sequence_length = None

        # Faster code path for teacher-forcing training
        if (helper is None and beam_width is None
                and decoding_strategy == 'train_greedy'):
            if inputs is None:
                raise ValueError(
                    "'input' must not be none "
                    "when using 'train_greedy' decoding strategy.")
            times = torch.arange(inputs.size(1),
                                 dtype=torch.long,
                                 device=inputs.device)
            times = times.unsqueeze(0).expand(inputs.size(0), -1)
            inputs = self.embed_tokens(inputs, times)
            if sequence_length is not None:
                inputs = mask_sequences(inputs, sequence_length)

            decoder_self_attention_bias = (attn.attention_bias_lower_triangle(
                inputs.size(1)))

            decoder_output = self._self_attention_stack(
                inputs,
                memory,
                decoder_self_attention_bias,
                memory_attention_bias,
                cache=None)
            logits = self._output_layer(decoder_output)
            sample_id = torch.argmax(logits, dim=-1)

            return TransformerDecoderOutput(logits, sample_id)

        # Inference code path.
        if max_decoding_length is None:
            max_decoding_length = self._hparams.max_decoding_length

        self._state_max_decoding_length = max_decoding_length

        if beam_width is None or beam_width == 1:  # Inference-like decoding
            # Prepare helper
            if helper is None:
                kwargs.update(decoding_strategy=decoding_strategy)
                if context is not None:
                    kwargs.update(start_tokens=context[:, 0])
                helper = self._create_or_get_helper(infer_mode, **kwargs)
            assert isinstance(helper, EmbeddingHelper)

            self._state_cache = self._init_cache(memory,
                                                 memory_attention_bias,
                                                 beam_search_decoding=False,
                                                 batch_size=helper.batch_size)
            if context is not None:
                assert self._state_context is not None
                pad_length = max_decoding_length - self._state_context.size(1)
                if pad_length > 0:
                    self._state_context = torch.cat(
                        (self._state_context,
                         self._state_context.new_zeros(
                             self._state_context.size(0), pad_length)),
                        dim=1)

            outputs, cache, sequence_lengths = self.dynamic_decode(
                helper,
                inputs=None,
                sequence_length=None,
                initial_state=None,
                max_decoding_length=max_decoding_length,
                impute_finished=impute_finished)
            del cache  # not used

            if context is not None:
                # Here the length of sample_id will be larger than that
                # of logits by 1, because an additional start token is
                # prepended to the returned sample_id. The start token
                # should be the first token of the given context.
                start_tokens = context[:, 0]
                outputs = TransformerDecoderOutput(
                    logits=outputs.logits,
                    sample_id=torch.cat(
                        [start_tokens.unsqueeze(1), outputs.sample_id], dim=1))
                sequence_lengths = sequence_lengths + 1

            return outputs, sequence_lengths

        else:  # Beam-search decoding
            # Ignore `decoding_strategy` and assume `helper` is not set.
            if helper is not None:
                raise ValueError("Must not set 'beam_width' and 'helper' "
                                 "simultaneously.")
            if context is not None:
                start_tokens = context[:, 0]
            else:
                if 'start_tokens' not in kwargs:
                    raise ValueError(
                        "'start_tokens' must be specified when using"
                        "beam search decoding.")
                start_tokens = kwargs['start_tokens']
            _batch_size = start_tokens.size(0)
            self._state_cache = self._init_cache(memory,
                                                 memory_attention_bias,
                                                 beam_search_decoding=True,
                                                 batch_size=_batch_size)
            end_token: int = kwargs.get('end_token')  # type: ignore

            # The output format is different when running beam search.
            sample_id, log_prob = self.beam_decode(
                start_tokens,
                end_token,
                embedding_fn=self.embed_tokens,
                beam_width=beam_width,
                length_penalty=length_penalty,
                decode_length=max_decoding_length)

            return {'sample_id': sample_id, 'log_prob': log_prob}
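
A hypothetical teacher-forcing call (a sketch; assumes an already constructed
texar.torch.modules.TransformerDecoder instance named `decoder` and a
vocabulary of 1000 token ids, mirroring the "train_greedy" path above):

import torch

batch_size, max_time = 4, 12
token_ids = torch.randint(0, 1000, (batch_size, max_time))  # fake inputs
lengths = torch.full((batch_size,), max_time, dtype=torch.long)
outputs = decoder(inputs=token_ids,            # `decoder` is assumed above
                  sequence_length=lengths,
                  decoding_strategy='train_greedy')
# outputs.logits: [batch_size, max_time, vocab_size];
# outputs.sample_id: per-step argmax of the logits.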
Example #9
def _dynamic_rnn_loop(cell: RNNCellBase[State],
                      inputs: torch.Tensor,
                      initial_state: State,
                      sequence_length: torch.LongTensor) \
        -> Tuple[torch.Tensor, State]:
    r"""Internal implementation of Dynamic RNN.

    Args:
        cell: An instance of RNNCell.
        inputs: A ``Tensor`` of shape ``[time, batch_size, input_size]``,
            or a nested tuple of such elements.
        initial_state: A ``Tensor`` of shape ``[batch_size, state_size]``,
            or if ``cell.state_size`` is a tuple, then this should be a tuple
            of tensors having shapes ``[batch_size, s]`` for ``s`` in
            ``cell.state_size``.
        sequence_length: An ``int32`` ``Tensor`` of shape ``[batch_size]``.

    Returns:
        Tuple ``(final_outputs, final_state)``.
        final_outputs:
            A ``Tensor`` of shape ``[time, batch_size, cell.output_size]``. If
            ``cell.output_size`` is a (possibly nested) tuple of ints or
            ``torch.Size`` objects, then this returns a
            (possibly nested) tuple of Tensors matching the corresponding
            shapes.
        final_state:
            A ``Tensor``, or possibly nested tuple of Tensors, matching
            in length and shapes to ``initial_state``.
    """
    state = initial_state
    time_steps = inputs.shape[0]
    all_outputs = []

    all_state = map_structure(lambda _: no_map(list), state)

    for i in range(time_steps):
        output, state = cell(inputs[i], state)
        all_outputs.append(output)
        map_structure_zip(lambda xs, x: xs.append(x), (all_state, state))
    # TODO: Do not compute everything regardless of sequence_length

    final_outputs = torch.stack(all_outputs, dim=0)
    final_outputs = mask_sequences(final_outputs,
                                   sequence_length=sequence_length,
                                   time_major=True)

    final_state = map_structure(lambda _: no_map(list), state)
    # pylint: disable=cell-var-from-loop
    # Our use case is fine because the function is called immediately and
    # exclusively in the current iteration of the loop.
    for batch_idx, time_idx in enumerate(sequence_length.tolist()):
        if time_idx > 0:
            map_structure_zip(
                lambda xs, x: xs.append(x[time_idx - 1][batch_idx]),
                (final_state, all_state))
        else:
            map_structure_zip(lambda xs, x: xs.append(x[batch_idx]),
                              (final_state, initial_state))
    # pylint: enable=cell-var-from-loop

    final_state = map_structure(lambda x: torch.stack(x, dim=0), final_state)

    return final_outputs, final_state
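
An illustrative call of the loop above (a sketch; `cell` is assumed to be a
texar.torch RNN cell wrapper implementing the RNNCellBase interface, and
`state0` its initial state for this batch -- both hypothetical here):

import torch

time_steps, batch_size, dim = 7, 3, 16
inputs = torch.randn(time_steps, batch_size, dim)   # time-major, per docstring
lengths = torch.tensor([7, 4, 1])
outputs, final_state = _dynamic_rnn_loop(cell, inputs, state0, lengths)
# outputs: [time_steps, batch_size, output_size], zeroed past each length;
# final_state: the state at each sequence's last valid step.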
Example #10
    def forward(
            self,  # type: ignore
            positions: Optional[torch.LongTensor] = None,
            sequence_length: Optional[torch.LongTensor] = None,
            **kwargs):
        r"""Embeds the positions.

        Either :attr:`positions` or :attr:`sequence_length` is required:

            - If both are given, :attr:`sequence_length` is used to mask out
              embeddings of those time steps beyond the respective sequence
              lengths.
            - If only :attr:`sequence_length` is given, then positions
              from 0 to ``sequence_length - 1`` are embedded.

        Args:
            positions (optional): A :tensor:`LongTensor` containing the position
                IDs to embed.
            sequence_length (optional): An :tensor:`LongTensor` of shape
                ``[batch_size]``. Time steps beyond the respective sequence
                lengths will have zero-valued embeddings.
            kwargs: Additional keyword arguments for
                :torch_nn:`functional.embedding` besides
                :attr:`params` and :attr:`ids`.

        Returns:
            A `Tensor` of shape `shape(inputs) + embedding dimension`.
        """
        # Gets embedder inputs
        if positions is None:
            if sequence_length is None:
                raise ValueError(
                    'Either `positions` or `sequence_length` is required.')
            max_length = torch.max(sequence_length)
            single_inputs = torch.arange(start=0, end=max_length)
            # Expands `single_inputs` to have shape [batch_size, max_length]
            inputs = single_inputs.unsqueeze(0)
            inputs = inputs.expand(len(sequence_length), -1).contiguous()
        else:
            inputs = positions

        ids_rank = inputs.dim()
        embedding = self._embedding
        inputs = inputs.to(device=embedding.device)
        # Gets dropout strategy
        st = self._hparams.dropout_strategy

        # Dropouts as 'item_type' before embedding
        if st == 'item_type':
            noise_shape = self._get_noise_shape(dropout_strategy=st,
                                                dropout_input=embedding)
            embedding = self._dropout_layer(embedding, noise_shape)

        # Embeds
        outputs = torch.nn.functional.embedding(inputs.type(torch.long),
                                                embedding, **kwargs)

        # Dropouts as 'item' or 'elements' after embedding
        if st != 'item_type':
            noise_shape = self._get_noise_shape(dropout_strategy=st,
                                                dropout_input=outputs,
                                                ids_rank=ids_rank)
            outputs = self._dropout_layer(outputs, noise_shape)

        # Optionally masks
        if sequence_length is not None:
            outputs = mask_sequences(outputs, sequence_length)

        return outputs
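
Hypothetical usage, assuming this forward belongs to
texar.torch.modules.PositionEmbedder (the snippet itself does not name the
class, so treat this as a sketch):

import torch
from texar.torch.modules import PositionEmbedder

embedder = PositionEmbedder(position_size=100)
out = embedder(sequence_length=torch.tensor([6, 3]))
# out: [2, 6, embedder.dim]; steps past each length are zero vectors.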