Esempio n. 1
0
    def forward(
            self,  # type: ignore
            inputs: torch.Tensor,
            sequence_length: Optional[Union[torch.LongTensor,
                                            List[int]]] = None,
            initial_state_fw: Optional[State] = None,
            initial_state_bw: Optional[State] = None,
            time_major: bool = False,
            return_cell_output: bool = False,
            return_output_size: bool = False):
        r"""Runs the forward and backward RNNs over :attr:`inputs` and passes
        each direction's cell outputs through that direction's output layer.

        Args:
            inputs: A 3D Tensor of shape ``[batch_size, max_time, dim]``, or
                ``[max_time, batch_size, dim]`` when ``time_major`` is `True`.
            sequence_length (optional): A 1D :tensor:`LongTensor` (or a list
                of ints) of shape ``[batch_size]`` giving the length of each
                sequence in the batch. Past an element's length the state is
                copied through and the outputs are zeroed out.
            initial_state_fw (optional): Initial state of the forward RNN.
            initial_state_bw (optional): Initial state of the backward RNN.
            time_major (bool): Layout of the :attr:`inputs` and
                :attr:`outputs` Tensors. `True` means
                ``[max_time, batch_size, depth]``; `False` (the default) means
                ``[batch_size, max_time, depth]``.
            return_cell_output (bool): If `True`, additionally return the raw
                RNN cell outputs, i.e. the values before the output layers.
            return_output_size (bool): If `True`, additionally return the size
                of the outputs produced after the output layers.

        Returns:
            A tuple that always starts with :attr:`(outputs, final_state)`;
            the optional items, when requested, follow in this order:

            - :attr:`outputs`: ``(outputs_fw, outputs_bw)``, the forward and
              backward outputs after the output layers. Each has shape
              ``[batch_size, max_time, output_dim]``, or
              ``[max_time, batch_size, output_dim]`` when ``time_major`` is
              `True`. If the cell output is a (nested) tuple of Tensors, each
              of ``outputs_fw`` / ``outputs_bw`` is a (nested) tuple with the
              same structure.

            - :attr:`final_state`: ``(final_state_fw, final_state_bw)``, the
              final states of the forward and backward RNNs. Each is a Tensor
              of shape ``[batch_size] + cell.state_size``, or a (nested) tuple
              of Tensors when ``cell.state_size`` is a (nested) tuple.

            - :attr:`cell_outputs` (only if ``return_cell_output``):
              ``(cell_outputs_fw, cell_outputs_bw)``, the cell outputs before
              the output layers. They have the same structure as
              :attr:`outputs` except for ``output_dim``.

            - :attr:`output_size` (only if ``return_output_size``):
              ``(output_size_fw, output_size_bw)``, the sizes of
              ``outputs_fw`` and ``outputs_bw``. Taking the forward direction
              as an example, ``output_size_fw`` is a (possibly nested tuple
              of) int. For a single int or an int array, ``outputs_fw`` has
              shape ``[batch/time, time/batch] + output_size_fw``. For a
              (nested) tuple, ``output_size_fw`` mirrors the structure of
              ``outputs_fw``. The backward direction behaves the same way.

            With both flags set the result is
            :attr:`(outputs, final_state, cell_outputs, output_size)`.
        """
        raw_outputs, final_states = bidirectional_dynamic_rnn(
            cell_fw=self._cell_fw,
            cell_bw=self._cell_bw,
            inputs=inputs,
            sequence_length=sequence_length,
            initial_state_fw=initial_state_fw,
            initial_state_bw=initial_state_bw,
            time_major=time_major)

        # Pair each direction's cell outputs with its own output layer and
        # project them: forward direction first, then backward.
        directions = zip((raw_outputs[0], raw_outputs[1]),
                         (self._output_layer_fw, self._output_layer_bw))
        projected = [
            _forward_output_layers(
                inputs=direction_output,
                output_layer=layer,
                time_major=time_major,
                sequence_length=sequence_length)
            for direction_output, layer in directions
        ]
        (fw_out, fw_size), (bw_out, bw_size) = projected

        result = [(fw_out, bw_out), final_states]
        if return_cell_output:
            result.append(raw_outputs)
        if return_output_size:
            result.append((fw_size, bw_size))
        return tuple(result)
Esempio n. 2
0
    def test_bidirectional_dynamic_rnn_initial_state(self):
        r"""Tests :meth:`~texar.utils.rnn.bidirectional_dynamic_rnn` with
        explicitly supplied forward and backward initial states, using RNN,
        LSTM and GRU cells in turn.
        """
        batch_size = self._batch_size
        hidden_size = self._hidden_size
        inputs = torch.rand(batch_size, self._max_time, self._input_size)

        # Random draws, in order: plain-RNN states (reused for GRU), then the
        # two-element LSTM state tuples.
        rnn_init_fw = torch.rand(batch_size, hidden_size)
        rnn_init_bw = torch.rand(batch_size, hidden_size)
        lstm_init_fw = tuple(
            torch.rand(batch_size, hidden_size) for _ in range(2))
        lstm_init_bw = tuple(
            torch.rand(batch_size, hidden_size) for _ in range(2))

        output_shape = torch.Size([batch_size, self._max_time, hidden_size])
        state_shape = torch.Size([batch_size, hidden_size])

        # Each case: (forward cell, backward cell, forward initial state,
        # backward initial state, whether each direction's final state is
        # expected to be a 2-tuple of Tensors as for the LSTM).
        cases = (
            (self._rnn_fw, self._rnn_bw, rnn_init_fw, rnn_init_bw, False),
            (self._lstm_fw, self._lstm_bw, lstm_init_fw, lstm_init_bw, True),
            (self._gru_fw, self._gru_bw, rnn_init_fw, rnn_init_bw, False),
        )
        for cell_fw, cell_bw, init_fw, init_bw, tuple_state in cases:
            outputs, output_state = bidirectional_dynamic_rnn(
                cell_fw=cell_fw,
                cell_bw=cell_bw,
                inputs=inputs,
                sequence_length=None,
                initial_state_fw=init_fw,
                initial_state_bw=init_bw,
                time_major=False)

            self.assertIsInstance(outputs, tuple)
            for direction in (0, 1):
                self.assertEqual(outputs[direction].shape, output_shape)

            self.assertIsInstance(output_state, tuple)
            for direction in (0, 1):
                state = output_state[direction]
                if tuple_state:
                    self.assertIsInstance(state, tuple)
                    self.assertEqual(state[0].shape, state_shape)
                    self.assertEqual(state[1].shape, state_shape)
                else:
                    self.assertEqual(state.shape, state_shape)