Ejemplo n.º 1
0
    def initialize(  # type: ignore
            self, helper: Helper,
            inputs: Optional[torch.Tensor],
            sequence_length: Optional[torch.LongTensor],
            initial_state: Optional[MaybeList[MaybeTuple[torch.Tensor]]]) -> \
            Tuple[torch.ByteTensor, torch.Tensor,
                  Optional[AttentionWrapperState]]:
        r"""Initializes the helper and the attention wrapper state.

        Args:
            helper: Decoding helper; its ``initialize(inputs,
                sequence_length)`` provides the initial finished flags and the
                first step inputs.
            inputs: Forwarded to ``helper.initialize``.
            sequence_length: Forwarded to ``helper.initialize``.
            initial_state: Initial state of the wrapped cell, or ``None``.

        Returns:
            A tuple ``(finished, step_inputs, state)``. ``state`` is ``None``
            when ``initial_state`` is ``None``. Otherwise it is
            ``self._cell.zero_state(...)`` with its ``cell_state`` field
            replaced by ``initial_state``.
        """
        finished, first_inputs = helper.initialize(inputs, sequence_length)
        if initial_state is None:
            return finished, first_inputs, None

        # The batch size is taken from dim 0 of the tensors in
        # ``initial_state``. If there are several, the last one visited by
        # ``map_structure`` decides.
        seen_sizes = []

        def _record_batch_size(leaf):
            if isinstance(leaf, torch.Tensor):
                seen_sizes.append(leaf.size(0))

        utils.map_structure(_record_batch_size, initial_state)
        batch_size = seen_sizes[-1] if seen_sizes else None

        wrapper_state = self._cell.zero_state(  # type: ignore
            batch_size=batch_size)
        return (finished, first_inputs,
                wrapper_state._replace(cell_state=initial_state))
Ejemplo n.º 2
0
    def _test_outputs(self,
                      decoder,
                      outputs,
                      final_state,
                      sequence_lengths,
                      test_mode=False):
        r"""Checks the outputs and final state returned by an attention decoder.

        Args:
            decoder: The decoder under test; its ``hparams.rnn_cell`` supplies
                the expected hidden size, cell type and number of layers.
            outputs: Decoder outputs, expected to be an
                ``AttentionRNNDecoderOutput``.
            final_state: Final decoder state; its ``cell_state`` is checked.
            sequence_lengths: Lengths returned by the decoder.
            test_mode: If ``True``, the time dimension is expected to equal
                ``max(sequence_lengths)``. Otherwise it must equal
                ``self._max_time``, and every length must equal that value.
        """
        cell_hparams = decoder.hparams.rnn_cell
        hidden_size = cell_hparams.kwargs.num_units
        cell_type = cell_hparams.type
        is_multi = cell_hparams.num_layers > 1

        self.assertIsInstance(outputs, AttentionRNNDecoderOutput)
        if test_mode:
            max_time = max(sequence_lengths).item()
        else:
            max_time = self._max_time
        self.assertEqual(outputs.logits.shape,
                         (self._batch_size, max_time, self._vocab_size))
        if not test_mode:
            np.testing.assert_array_equal(sequence_lengths,
                                          [max_time] * self._batch_size)

        expected_size = (self._batch_size, hidden_size)

        def _check_size(tensor):
            self.assertEqual(tensor.size(), expected_size)

        cell_state = final_state.cell_state
        map_structure(_check_size, cell_state)

        # Unwrap the state: a list per layer for multi-layer cells, then a
        # tuple for LSTM cells, down to a single tensor.
        leaf = cell_state
        if is_multi:
            self.assertIsInstance(leaf, list)
            leaf = leaf[0]
        if cell_type == "LSTMCell":
            self.assertIsInstance(leaf, tuple)
            leaf = leaf[0]
        self.assertIsInstance(leaf, torch.Tensor)
Ejemplo n.º 3
0
    def forward(self,  # type: ignore
                input: torch.Tensor, state: Optional[State] = None) \
            -> Tuple[torch.Tensor, State]:
        r"""Runs the wrapped cell for one step with dropout applied.

        In training mode with variational recurrent dropout enabled, a mask
        is kept for each of ``input``, ``output`` and ``state`` whose keep
        probability is not ``1.0``. Each mask is created via
        ``self._new_mask`` when none exists yet and reused afterwards.

        Dropout (``self._dropout``) is then applied to the input before the
        wrapped cell runs, and afterwards to its output and to every tensor
        in its new state.

        Args:
            input: Input tensor for this step; ``input.size(0)`` is used as
                the batch size for the masks.
            state: Previous state, forwarded to the wrapped cell.

        Returns:
            ``(output, new_state)`` after dropout.

        Raises:
            ValueError: If an existing mask was built for a different batch
                size.
        """
        if self.training and self._variational_recurrent:
            batch_size = input.size(0)
            mask_specs = (('input', self.input_size),
                          ('output', self.hidden_size),
                          ('state', self.hidden_size))
            for kind, feature_size in mask_specs:
                keep_prob = getattr(self, f'_{kind}_keep_prob')
                if keep_prob == 1.0:
                    continue
                mask_attr = f'_recurrent_{kind}_mask'
                existing = getattr(self, mask_attr)
                if existing is None:
                    # No mask yet: build one sized for the current batch.
                    setattr(self, mask_attr,
                            self._new_mask(batch_size, feature_size, keep_prob))
                elif existing.size(0) != batch_size:
                    raise ValueError(
                        "Variational recurrent dropout mask does not "
                        "support variable batch sizes across time steps")

        dropped_input = self._dropout(
            input, self._input_keep_prob, self._recurrent_input_mask)
        raw_output, raw_state = super().forward(dropped_input, state)
        output = self._dropout(
            raw_output, self._output_keep_prob, self._recurrent_output_mask)

        def _drop_state(tensor):
            return self._dropout(
                tensor, self._state_keep_prob, self._recurrent_state_mask)

        return output, utils.map_structure(_drop_state, raw_state)
Ejemplo n.º 4
0
    def forward(self,    # type: ignore
                batch_size: Union[int, torch.Tensor],
                value: Any = None) -> Any:
        r"""Creates output tensor(s) that has the given value.

        Args:
            batch_size: An ``int`` or ``int`` scalar ``Tensor``, the
                batch size.
            value (optional): A scalar, the value that the output tensor(s)
                has. If ``None``, the connector's own ``value`` (``"value"``
                in :attr:`hparams`) is used.

        :returns:
            A (structure of) ``Tensor`` whose structure is the same as
            :attr:`output_size`, with value specified by
            ``value`` or :attr:`hparams`.
        """
        # The docstring always advertised ``value``; it is now a real
        # keyword argument. Omitting it keeps the previous behaviour.
        fill_value = self.value if value is None else value

        def full_tensor(size):
            # A ``torch.Size`` leaf is a full trailing shape; any other leaf
            # is the size of a single trailing dimension.
            if isinstance(size, torch.Size):
                return torch.full((batch_size,) + size, fill_value)
            return torch.full((batch_size, size), fill_value)

        return utils.map_structure(full_tensor, self._output_size)
Ejemplo n.º 5
0
def _dynamic_rnn_loop(cell: RNNCellBase[State],
                      inputs: torch.Tensor,
                      initial_state: State,
                      sequence_length: torch.LongTensor) \
        -> Tuple[torch.Tensor, State]:
    r"""Internal implementation of Dynamic RNN.

    Runs ``cell`` over every time step of ``inputs`` (all steps are always
    computed, regardless of ``sequence_length``). It then masks the stacked
    outputs with ``mask_sequences`` and picks, for each batch entry, the
    state after its last valid step.

    Args:
        cell: An instance of RNNCell.
        inputs: A time-major ``Tensor`` of shape
            ``[time, batch_size, input_size]``. Nested structures are not
            supported, because the loop indexes ``inputs`` directly.
        initial_state: A ``Tensor`` of shape ``[batch_size, state_size]``,
            or if ``cell.state_size`` is a tuple, then this should be a tuple
            of tensors having shapes ``[batch_size, s]`` for ``s`` in
            ``cell.state_size``.
        sequence_length: A ``Tensor`` of shape ``[batch_size]`` with the
            length of each sequence (required). Each entry is expected to lie
            in ``[0, time]``.

    Returns:
        Tuple ``(final_outputs, final_state)``.
        final_outputs:
            A ``Tensor`` of shape ``[time, batch_size, ...]``. It is the
            per-step cell outputs stacked along dim 0 and passed through
            ``mask_sequences`` with ``sequence_length``. The per-step outputs
            must be single tensors, since they are combined with
            ``torch.stack``.
        final_state:
            A ``Tensor``, or possibly nested tuple of Tensors, matching
            in structure ``initial_state``. Entry ``b`` is the state after
            step ``sequence_length[b]``, or ``initial_state``'s entry ``b``
            when that length is 0.
    """
    state = initial_state
    time_steps = inputs.shape[0]
    all_outputs = []

    # One list per state leaf, collecting that leaf after every time step.
    all_state = map_structure(lambda _: no_map(list), state)

    for i in range(time_steps):
        output, state = cell(inputs[i], state)
        all_outputs.append(output)
        map_structure_zip(lambda xs, x: xs.append(x), (all_state, state))
    # TODO: Do not compute everything regardless of sequence_length

    final_outputs = torch.stack(all_outputs, dim=0)
    final_outputs = mask_sequences(final_outputs,
                                   sequence_length=sequence_length,
                                   time_major=True)

    final_state = map_structure(lambda _: no_map(list), state)
    # The loop variables are bound as lambda defaults, so the callbacks never
    # depend on late binding of ``batch_idx`` / ``time_idx``.
    for batch_idx, time_idx in enumerate(sequence_length.tolist()):
        if time_idx > 0:
            # State produced by the last valid step of this sequence.
            map_structure_zip(
                lambda xs, x, t=time_idx, b=batch_idx: xs.append(x[t - 1][b]),
                (final_state, all_state))
        else:
            # Zero-length sequence: fall back to the initial state.
            map_structure_zip(
                lambda xs, x, b=batch_idx: xs.append(x[b]),
                (final_state, initial_state))

    final_state = map_structure(lambda x: torch.stack(x, dim=0), final_state)

    return final_outputs, final_state
Ejemplo n.º 6
0
    def dynamic_decode(self, helper: Helper, inputs: Optional[torch.Tensor],
                       sequence_length: Optional[torch.LongTensor],
                       initial_state: Optional[State],
                       max_decoding_length: Optional[int] = None,
                       impute_finished: bool = False) \
            -> Tuple[Output, Optional[State], torch.LongTensor]:
        r"""Generic routine for dynamic decoding. Please check the
        `documentation
        <https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_decode>`_
        for the TensorFlow counterpart.

        Args:
            helper: Decoding helper, passed to :meth:`initialize` and
                :meth:`step`.
            inputs: Forwarded to :meth:`initialize`.
            sequence_length: Forwarded to :meth:`initialize`.
            initial_state: Forwarded to :meth:`initialize`.
            max_decoding_length: If not ``None``, at most this many steps are
                run. A value ``<= 0`` marks every sequence finished up front.
            impute_finished: If ``True``, for batch entries already finished
                before a step, that step's output is replaced by zeros and the
                previous state is carried over unchanged.

        Returns:
            A tuple of output, final state, and sequence lengths. Note that
            final state could be ``None``, when all sequences are of zero length
            and :attr:`initial_state` is also ``None``.
        """

        def _where_finished(done, cur, new):
            # ``done`` holds one flag per batch entry. PyTorch broadcasting
            # aligns trailing dimensions, so reshape it to
            # ``[batch, 1, ..., 1]`` to select whole rows of ``cur``/``new``.
            mask = done.reshape(
                done.shape + (1,) * (new.dim() - done.dim()))
            return torch.where(mask, cur, new)

        # Decode
        finished, step_inputs, state = self.initialize(helper, inputs,
                                                       sequence_length,
                                                       initial_state)

        zero_outputs = step_inputs.new_zeros(step_inputs.size(0),
                                             self.output_size)

        if max_decoding_length is not None:
            # Out-of-place, so the tensor returned by the helper is not
            # modified.
            finished = finished | (max_decoding_length <= 0)
        sequence_lengths = torch.zeros_like(finished,
                                            dtype=torch.long,
                                            device=finished.device)
        time = 0

        outputs = []

        while (not torch.all(finished).item() and
               (max_decoding_length is None or time < max_decoding_length)):
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = self.step(helper, time, step_inputs, state)

            if getattr(self, 'tracks_own_finished', False):
                next_finished = decoder_finished
            else:
                next_finished = decoder_finished | finished

            # Zero out output values past finish
            if impute_finished:
                done = finished
                emit = utils.map_structure_zip(
                    lambda new, cur: _where_finished(done, cur, new),
                    (next_outputs, zero_outputs))
                next_state = utils.map_structure_zip(
                    lambda new, cur: _where_finished(done, cur, new),
                    (decoder_state, state))
            else:
                emit = next_outputs
                next_state = decoder_state

            outputs.append(emit)
            # Entries still running at this step get length ``time + 1``.
            sequence_lengths.index_fill_(dim=0,
                                         value=time + 1,
                                         index=torch.nonzero(
                                             (~finished).long()).flatten())
            time += 1
            finished = next_finished
            step_inputs = next_inputs
            state = next_state

        final_outputs = utils.map_structure_zip(
            lambda *tensors: torch.stack(tensors),
            outputs)  # output at each time step may be a namedtuple
        final_state = state
        final_sequence_lengths = sequence_lengths

        try:
            final_outputs, final_state = self.finalize(final_outputs,
                                                       final_state,
                                                       final_sequence_lengths)
        except NotImplementedError:
            # ``finalize`` is optional for subclasses.
            pass

        if not self._output_time_major:
            final_outputs = utils.map_structure(
                lambda x: x.transpose(0, 1)
                if x.dim() >= 2 else x, final_outputs)

        return final_outputs, final_state, final_sequence_lengths
Ejemplo n.º 7
0
    def test_get_rnn_cell(self):
        r"""Tests :func:`get_rnn_cell` with a 3-layer wrapped LSTM config.

        Builds a cell with dropout (variational recurrent), residual and
        highway options enabled. Checks the wrapper nesting, sizes,
        forget-gate bias and dropout settings of every layer. Then steps the
        cell on zero inputs and checks output and state shapes.
        """
        input_size = 10
        hparams = {
            'type': 'LSTMCell',
            'kwargs': {
                'num_units': 20,
                'forget_bias': 1.0,
            },
            'num_layers': 3,
            'dropout': {
                'input_keep_prob': 0.5,
                'output_keep_prob': 0.5,
                'state_keep_prob': 0.5,
                'variational_recurrent': True
            },
            'residual': True,
            'highway': True,
        }
        # Combine the custom values with default_rnn_cell_hparams().
        hparams = HParams(hparams, default_rnn_cell_hparams())

        rnn_cell = get_rnn_cell(input_size, hparams)
        # Multiple layers are expected to come back as a MultiRNNCell.
        self.assertIsInstance(rnn_cell, wrappers.MultiRNNCell)
        self.assertEqual(len(rnn_cell._cell), hparams.num_layers)
        self.assertEqual(rnn_cell.input_size, input_size)
        self.assertEqual(rnn_cell.hidden_size, hparams.kwargs.num_units)

        for idx, cell in enumerate(rnn_cell._cell):
            # Only the first layer sees the raw input size.
            layer_input_size = (input_size
                                if idx == 0 else hparams.kwargs.num_units)
            self.assertEqual(cell.input_size, layer_input_size)
            self.assertEqual(cell.hidden_size, hparams.kwargs.num_units)

            # Expected nesting: Highway(Residual(Dropout(LSTM))) for layers
            # after the first; the first layer is just Dropout(LSTM).
            if idx > 0:
                highway = cell
                residual = highway._cell
                dropout = residual._cell
                self.assertIsInstance(highway, wrappers.HighwayWrapper)
                self.assertIsInstance(residual, wrappers.ResidualWrapper)
            else:
                dropout = cell
            lstm = dropout._cell
            builtin_lstm = lstm._cell
            self.assertIsInstance(dropout, wrappers.DropoutWrapper)
            self.assertIsInstance(lstm, wrappers.LSTMCell)
            self.assertIsInstance(builtin_lstm, nn.LSTMCell)
            # Slice [h:2h] of nn.LSTMCell's bias_ih is the forget-gate block;
            # it should be initialised to forget_bias.
            h = hparams.kwargs.num_units
            forget_bias = builtin_lstm.bias_ih[h:(2 * h)]
            self.assertTrue((forget_bias == hparams.kwargs.forget_bias).all())

            for key in ['input', 'output', 'state']:
                self.assertEqual(getattr(dropout, f'_{key}_keep_prob'),
                                 hparams.dropout[f'{key}_keep_prob'])
            self.assertTrue(dropout._variational_recurrent)

        # Step the cell a few times with a fixed batch size. The batch size
        # stays constant because the variational dropout masks are reused
        # across steps.
        batch_size = 8
        seq_len = 6
        state = None
        for step in range(seq_len):
            input = torch.zeros(batch_size, input_size)
            output, state = rnn_cell(input, state)
            self.assertEqual(output.shape,
                             (batch_size, hparams.kwargs.num_units))
            # One state entry per layer; every leaf is [batch, num_units].
            self.assertEqual(len(state), hparams.num_layers)
            utils.map_structure(
                lambda s: self.assertEqual(s.shape, (batch_size, hparams.kwargs
                                                     .num_units)), state)