Example #1
0
    def test_dynamic_rnn_initial_state(self):
        r"""Tests :meth:`~texar.utils.rnn.dynamic_rnn`.
        """
        batch = self._batch_size
        time = self._max_time
        hidden = self._hidden_size

        inputs = torch.rand(batch, time, self._input_size)

        # A plain RNN/GRU takes a single state tensor; an LSTM takes an
        # (h, c) pair of tensors.
        single_state = torch.rand(batch, hidden)
        pair_state = (torch.rand(batch, hidden), torch.rand(batch, hidden))

        # Expected shapes are the same for every cell type below.
        out_shape = torch.Size([batch, time, hidden])
        state_shape = torch.Size([batch, hidden])

        # RNN
        outputs, final_state = dynamic_rnn(
            self._rnn, inputs, sequence_length=None,
            initial_state=single_state, time_major=False)
        self.assertEqual(outputs.shape, out_shape)
        self.assertEqual(final_state.shape, state_shape)

        # LSTM
        outputs, final_state = dynamic_rnn(
            self._lstm, inputs, sequence_length=None,
            initial_state=pair_state, time_major=False)
        self.assertEqual(outputs.shape, out_shape)
        self.assertIsInstance(final_state, tuple)
        for part in final_state:
            self.assertEqual(part.shape, state_shape)

        # GRU
        outputs, final_state = dynamic_rnn(
            self._gru, inputs, sequence_length=None,
            initial_state=single_state, time_major=False)
        self.assertEqual(outputs.shape, out_shape)
        self.assertEqual(final_state.shape, state_shape)
Example #2
0
    def forward(
            self,  # type: ignore
            inputs: torch.Tensor,
            sequence_length: Optional[Union[torch.LongTensor,
                                            List[int]]] = None,
            initial_state: Optional[State] = None,
            time_major: bool = False,
            return_cell_output: bool = False,
            return_output_size: bool = False):
        r"""Encodes the inputs.

        Args:
            inputs: A 3D Tensor of shape ``[batch_size, max_time, dim]``.
                The first two dimensions :attr:`batch_size` and
                :attr:`max_time` are swapped when :attr:`time_major` is
                `True`.
            sequence_length (optional): A 1D :tensor:`LongTensor` of shape
                ``[batch_size]`` holding the length of each batch element.
                Used to copy-through state and zero-out outputs past a
                batch element's sequence length.
            initial_state (optional): Initial state of the RNN.
            time_major (bool): Shape layout of the :attr:`inputs` and
                :attr:`outputs` Tensors. If `True`, they are of shape
                ``[max_time, batch_size, depth]``; if `False` (default),
                ``[batch_size, max_time, depth]``.
            return_cell_output (bool): Whether to also return the raw RNN
                cell output, i.e., the result prior to the output layer.
            return_output_size (bool): Whether to also return the size of
                the output (i.e., the result after the output layer).

        Returns:
            - By default (both ``return_cell_output`` and ``return_output_size``
              are `False`), returns a pair :attr:`(outputs, final_state)`,
              where

              - :attr:`outputs`: The RNN output produced by the output layer
                (if one exists) or by the RNN cell (otherwise). The tensor
                has shape ``[batch_size, max_time, output_size]`` when
                ``time_major`` is `False`, or
                ``[max_time, batch_size, output_size]`` when ``time_major``
                is `True`. If the RNN cell output is a (nested) tuple of
                Tensors, :attr:`outputs` is a (nested) tuple with the same
                nest structure as the cell output.

              - :attr:`final_state`: The final RNN state — a Tensor of shape
                ``[batch_size] + cell.state_size``, or a (nested) tuple of
                Tensors if ``cell.state_size`` is a (nested) tuple.

            - If ``return_cell_output`` is True, returns a triple
              :attr:`(outputs, final_state, cell_outputs)`

              - :attr:`cell_outputs`: The outputs of the RNN cell prior to
                the output layer, structured like :attr:`outputs` except
                for the ``output_dim``.

            - If ``return_output_size`` is `True`, returns a tuple
              :attr:`(outputs, final_state, output_size)`

              - :attr:`output_size`: A (possibly nested tuple of) int
                giving the size of :attr:`outputs`. If a single int or an
                int array, ``outputs`` has shape
                ``[batch/time, time/batch] + output_size``. If a (nested)
                tuple, ``output_size`` shares the structure of ``outputs``.

            - If both ``return_cell_output`` and ``return_output_size`` are
              `True`, returns
              :attr:`(outputs, final_state, cell_outputs, output_size)`.
        """
        # Run the recurrence itself; the output layer is applied afterwards.
        cell_outputs, final_state = dynamic_rnn(
            cell=self._cell,
            inputs=inputs,
            sequence_length=sequence_length,
            initial_state=initial_state,
            time_major=time_major)

        # Project the cell outputs through the (optional) output layer.
        outputs, output_size = _forward_output_layers(
            inputs=cell_outputs,
            output_layer=self._output_layer,
            time_major=time_major,
            sequence_length=sequence_length)

        # Append the optional extras in documented order:
        # cell_outputs first, then output_size.
        extras: list = []
        if return_cell_output:
            extras.append(cell_outputs)
        if return_output_size:
            extras.append(output_size)
        return (outputs, final_state, *extras)