def forward(self,  # type: ignore
            inputs: torch.Tensor,
            sequence_length: Optional[Union[torch.LongTensor,
                                            List[int]]] = None,
            initial_state_fw: Optional[State] = None,
            initial_state_bw: Optional[State] = None,
            time_major: bool = False,
            return_cell_output: bool = False,
            return_output_size: bool = False):
    r"""Runs the bidirectional RNN over the inputs and applies the
    per-direction output layers.

    Args:
        inputs: A 3D Tensor of shape ``[batch_size, max_time, dim]``, or
            ``[max_time, batch_size, dim]`` if ``time_major`` is `True`.
        sequence_length (optional): A 1D :tensor:`LongTensor` (or a list
            of ints) of shape ``[batch_size]`` holding the sequence
            lengths of the batch inputs. Used to copy-through state and
            zero-out outputs when past a batch element's sequence length.
        initial_state_fw (optional): Initial state of the forward RNN.
        initial_state_bw (optional): Initial state of the backward RNN.
        time_major (bool): The shape format of the :attr:`inputs` and
            :attr:`outputs` Tensors. If `True`, these tensors are of shape
            ``[max_time, batch_size, depth]``. If `False` (default), they
            are of shape ``[batch_size, max_time, depth]``.
        return_cell_output (bool): Whether to additionally return the
            outputs of the RNN cells, i.e., the results prior to the
            output layers.
        return_output_size (bool): Whether to additionally return the
            size of the outputs, i.e., of the results after the output
            layers.

    Returns:
        - By default (both ``return_cell_output`` and
          ``return_output_size`` are `False`), a pair
          :attr:`(outputs, final_state)`, where

          - :attr:`outputs`: A tuple ``(outputs_fw, outputs_bw)`` with the
            forward and the backward RNN outputs, each of shape
            ``[batch_size, max_time, output_dim]`` if ``time_major`` is
            `False`, or ``[max_time, batch_size, output_dim]`` if
            ``time_major`` is `True`. If the RNN cell output is a (nested)
            tuple of Tensors, then ``outputs_fw`` and ``outputs_bw`` are
            (nested) tuples with the same structure as the cell output.

          - :attr:`final_state`: A tuple
            ``(final_state_fw, final_state_bw)`` with the final states of
            the forward and backward RNNs, each of which is a Tensor of
            shape ``[batch_size] + cell.state_size``, or a (nested) tuple
            of Tensors if ``cell.state_size`` is a (nested) tuple.

        - If ``return_cell_output`` is `True`, a triple
          :attr:`(outputs, final_state, cell_outputs)`, where

          - :attr:`cell_outputs`: A tuple
            ``(cell_outputs_fw, cell_outputs_bw)`` with the outputs of the
            forward and backward RNN cells prior to the output layers,
            having the same structure as :attr:`outputs` except for the
            ``output_dim``.

        - If ``return_output_size`` is `True`, a triple
          :attr:`(outputs, final_state, output_size)`, where

          - :attr:`output_size`: A tuple
            ``(output_size_fw, output_size_bw)`` with the sizes of
            ``outputs_fw`` and ``outputs_bw``, respectively. Taking
            ``*_fw`` as the example, ``output_size_fw`` is a (possibly
            nested tuple of) int. If it is a single int or an int array,
            then ``outputs_fw`` has shape
            ``[batch/time, time/batch] + output_size_fw``. If it is a
            (nested) tuple, then ``output_size_fw`` has the same structure
            as ``outputs_fw``. The same applies to ``output_size_bw``.

        - If both ``return_cell_output`` and ``return_output_size`` are
          `True`, a 4-tuple
          :attr:`(outputs, final_state, cell_outputs, output_size)`.
    """
    rnn_outputs, final_states = bidirectional_dynamic_rnn(
        cell_fw=self._cell_fw,
        cell_bw=self._cell_bw,
        inputs=inputs,
        sequence_length=sequence_length,
        initial_state_fw=initial_state_fw,
        initial_state_bw=initial_state_bw,
        time_major=time_major)

    def _apply_output_layer(direction_outputs, layer):
        # Both directions share everything except the cell outputs and
        # the output layer they are fed through.
        return _forward_output_layers(
            inputs=direction_outputs,
            output_layer=layer,
            time_major=time_major,
            sequence_length=sequence_length)

    # Forward direction first, then backward.
    fw_outputs, fw_size = _apply_output_layer(
        rnn_outputs[0], self._output_layer_fw)
    bw_outputs, bw_size = _apply_output_layer(
        rnn_outputs[1], self._output_layer_bw)

    # The optional extras are appended in a fixed order: cell outputs
    # before output sizes.
    result = [(fw_outputs, bw_outputs), final_states]
    if return_cell_output:
        result.append(rnn_outputs)
    if return_output_size:
        result.append((fw_size, bw_size))
    return tuple(result)
def test_bidirectional_dynamic_rnn_initial_state(self):
    r"""Tests :meth:`~texar.utils.rnn.bidirectional_dynamic_rnn` when
    explicit initial states are supplied, for RNN, LSTM and GRU cells.
    """
    inputs = torch.rand(self._batch_size, self._max_time, self._input_size)

    # Single-tensor states, shared by the RNN and the GRU cases.
    flat_state_fw = torch.rand(self._batch_size, self._hidden_size)
    flat_state_bw = torch.rand(self._batch_size, self._hidden_size)

    # LSTM states are pairs of tensors.
    pair_state_fw = (torch.rand(self._batch_size, self._hidden_size),
                     torch.rand(self._batch_size, self._hidden_size))
    pair_state_bw = (torch.rand(self._batch_size, self._hidden_size),
                     torch.rand(self._batch_size, self._hidden_size))

    def expected_output_shape():
        return torch.Size(
            [self._batch_size, self._max_time, self._hidden_size])

    def expected_state_shape():
        return torch.Size([self._batch_size, self._hidden_size])

    def run_and_check_outputs(cell_fw, cell_bw, state_fw, state_bw):
        # Runs the bidirectional RNN batch-major without sequence
        # lengths, verifies both directions' output shapes, and hands
        # back the final state for cell-specific checks.
        outputs, final_state = bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=inputs,
            sequence_length=None,
            initial_state_fw=state_fw,
            initial_state_bw=state_bw,
            time_major=False)

        self.assertIsInstance(outputs, tuple)
        for direction in (0, 1):
            self.assertEqual(outputs[direction].shape,
                             expected_output_shape())
        self.assertIsInstance(final_state, tuple)
        return final_state

    # RNN
    final_state = run_and_check_outputs(
        self._rnn_fw, self._rnn_bw, flat_state_fw, flat_state_bw)
    for direction in (0, 1):
        self.assertEqual(final_state[direction].shape,
                         expected_state_shape())

    # LSTM
    final_state = run_and_check_outputs(
        self._lstm_fw, self._lstm_bw, pair_state_fw, pair_state_bw)
    for direction in (0, 1):
        self.assertIsInstance(final_state[direction], tuple)
        for part in (0, 1):
            self.assertEqual(final_state[direction][part].shape,
                             expected_state_shape())

    # GRU
    final_state = run_and_check_outputs(
        self._gru_fw, self._gru_bw, flat_state_fw, flat_state_bw)
    for direction in (0, 1):
        self.assertEqual(final_state[direction].shape,
                         expected_state_shape())