def _test_outputs(self, decoder, outputs, final_state, sequence_lengths,
                  test_mode=False):
    hidden_size = decoder.hparams.rnn_cell.kwargs.num_units
    cell_type = decoder.hparams.rnn_cell.type
    is_multi = decoder.hparams.rnn_cell.num_layers > 1

    self.assertIsInstance(outputs, AttentionRNNDecoderOutput)
    max_time = (self._max_time if not test_mode
                else max(sequence_lengths).item())
    self.assertEqual(
        outputs.logits.shape,
        (self._batch_size, max_time, self._vocab_size))
    if not test_mode:
        np.testing.assert_array_equal(
            sequence_lengths, [max_time] * self._batch_size)

    map_structure(
        lambda t: self.assertEqual(
            t.size(), (self._batch_size, hidden_size)),
        final_state.cell_state)
    state = final_state.cell_state
    if is_multi:
        self.assertIsInstance(state, list)
        state = state[0]
    if cell_type == "LSTMCell":
        self.assertIsInstance(state, tuple)
        state = state[0]
    self.assertIsInstance(state, torch.Tensor)
def move_memory(data, device):
    r"""Returns a copy of ``data`` with every contained tensor moved to
    ``device``; non-tensor values are left unchanged. :class:`Batch`
    instances are rebuilt so the result is still a :class:`Batch`.
    """
    def _move_fn(x):
        if isinstance(x, torch.Tensor):
            return x.to(device=device, non_blocking=True)
        return x

    if isinstance(data, Batch):
        return Batch(len(data), batch={
            key: map_structure(_move_fn, value)
            for key, value in data.items()
        })
    return map_structure(_move_fn, data)
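# Hedged usage sketch for `move_memory` (not from the original source; assumes
# `torch` is imported as in the surrounding module, and the nested payload
# below is hypothetical). Any structure that `map_structure` understands
# (lists, tuples, dicts) is handled the same way.
def _example_move_memory():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    nested = {'tokens': torch.zeros(4, 16, dtype=torch.long),
              'lengths': torch.tensor([16, 12, 9, 16])}
    moved = move_memory(nested, device)
    # Every tensor in the structure now lives on `device`; non-tensor
    # values would pass through unchanged.
    assert all(t.device.type == device.type for t in moved.values())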
def forward(self,  # type: ignore
            input: torch.Tensor,
            state: Optional[State] = None) \
        -> Tuple[torch.Tensor, State]:
    if self.training and self._variational_recurrent:
        # Create or check recurrent masks.
        batch_size = input.size(0)
        for name, size in [('input', self.input_size),
                           ('output', self.hidden_size),
                           ('state', self.hidden_size)]:
            prob = getattr(self, f'_{name}_keep_prob')
            if prob == 1.0:
                continue
            mask = getattr(self, f'_recurrent_{name}_mask')
            if mask is None:
                # Initialize the mask according to current batch size.
                mask = self._new_mask(batch_size, size, prob)
                setattr(self, f'_recurrent_{name}_mask', mask)
            else:
                # Check that size matches.
                if mask.size(0) != batch_size:
                    raise ValueError(
                        "Variational recurrent dropout mask does not "
                        "support variable batch sizes across time steps")

    input = self._dropout(input, self._input_keep_prob,
                          self._recurrent_input_mask)
    output, new_state = super().forward(input, state)
    output = self._dropout(output, self._output_keep_prob,
                           self._recurrent_output_mask)
    new_state = utils.map_structure(
        lambda x: self._dropout(
            x, self._state_keep_prob, self._recurrent_state_mask),
        new_state)
    return output, new_state
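# Minimal standalone sketch of the idea behind variational recurrent dropout
# (an illustration, not the wrapper's implementation): a single Bernoulli mask
# is sampled per sequence and reused at every time step, instead of resampling
# a fresh mask each step.
def _example_variational_dropout(x: torch.Tensor,
                                 keep_prob: float) -> torch.Tensor:
    # x: [time, batch, size]. Sample one [batch, size] mask, scale by
    # 1 / keep_prob (inverted dropout), and broadcast it across time.
    mask = torch.bernoulli(x.new_full(x.shape[1:], keep_prob)) / keep_prob
    return x * mask.unsqueeze(0)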
def forward(self,  # type: ignore
            batch_size: Union[int, torch.Tensor]) -> Any:
    r"""Creates output tensor(s) filled with the constant value.

    Args:
        batch_size: An ``int`` or ``int`` scalar ``Tensor``, the
            batch size.

    :returns: A (structure of) ``Tensor`` whose structure is the same as
        :attr:`output_size`, filled with the value specified by
        ``"value"`` in :attr:`hparams`.
    """
    def full_tensor(x):
        if isinstance(x, torch.Size):
            return torch.full((batch_size,) + x, self.value)
        else:
            return torch.full((batch_size, x), self.value)

    output = utils.map_structure(full_tensor, self._output_size)
    return output
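# Hedged usage sketch with hypothetical sizes (not from the original source):
# given an LSTM-style output size of (torch.Size([256]), torch.Size([256])),
# the `full_tensor` logic above produces a pair of [batch_size, 256] tensors
# filled with the constant value.
def _example_constant_state():
    batch_size, value = 4, 0.
    output_size = (torch.Size([256]), torch.Size([256]))
    state = tuple(torch.full((batch_size,) + size, value)
                  for size in output_size)
    assert [s.shape for s in state] == [torch.Size([4, 256])] * 2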
def _dynamic_rnn_loop(cell: RNNCellBase[State],
                      inputs: torch.Tensor,
                      initial_state: State,
                      sequence_length: torch.LongTensor) \
        -> Tuple[torch.Tensor, State]:
    r"""Internal implementation of Dynamic RNN.

    Args:
        cell: An instance of RNNCell.
        inputs: A ``Tensor`` of shape ``[time, batch_size, input_size]``,
            or a nested tuple of such elements.
        initial_state: A ``Tensor`` of shape ``[batch_size, state_size]``,
            or if ``cell.state_size`` is a tuple, then this should be a
            tuple of tensors having shapes ``[batch_size, s]`` for ``s``
            in ``cell.state_size``.
        sequence_length: An ``int32`` ``Tensor`` of shape
            ``[batch_size]``.

    Returns:
        Tuple ``(final_outputs, final_state)``.

        final_outputs: A ``Tensor`` of shape
            ``[time, batch_size, cell.output_size]``. If
            ``cell.output_size`` is a (possibly nested) tuple of ints or
            ``torch.Size`` objects, then this returns a (possibly nested)
            tuple of Tensors matching the corresponding shapes.

        final_state: A ``Tensor``, or possibly nested tuple of Tensors,
            matching in length and shapes to ``initial_state``.
    """
    state = initial_state
    time_steps = inputs.shape[0]
    all_outputs = []

    all_state = map_structure(lambda _: no_map(list), state)

    for i in range(time_steps):
        output, state = cell(inputs[i], state)
        all_outputs.append(output)
        map_structure_zip(lambda xs, x: xs.append(x), (all_state, state))
    # TODO: Do not compute everything regardless of sequence_length

    final_outputs = torch.stack(all_outputs, dim=0)
    final_outputs = mask_sequences(final_outputs,
                                   sequence_length=sequence_length,
                                   time_major=True)

    final_state = map_structure(lambda _: no_map(list), state)
    # pylint: disable=cell-var-from-loop
    # Our use case is fine because the function is called immediately and
    # exclusively in the current iteration of the loop.
    for batch_idx, time_idx in enumerate(sequence_length.tolist()):
        if time_idx > 0:
            map_structure_zip(
                lambda xs, x: xs.append(x[time_idx - 1][batch_idx]),
                (final_state, all_state))
        else:
            map_structure_zip(lambda xs, x: xs.append(x[batch_idx]),
                              (final_state, initial_state))
    # pylint: enable=cell-var-from-loop
    final_state = map_structure(lambda x: torch.stack(x, dim=0),
                                final_state)

    return final_outputs, final_state
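# Standalone sketch of the final-state gathering performed above (an
# illustration, not the library routine): for each example in the batch, pick
# the state at its last valid step, falling back to the initial state for
# zero-length sequences.
def _example_gather_final_states(all_states: torch.Tensor,
                                 sequence_length: torch.LongTensor,
                                 initial_state: torch.Tensor) -> torch.Tensor:
    # all_states: [time, batch, hidden]; sequence_length: [batch];
    # initial_state: [batch, hidden].
    finals = [all_states[t - 1, b] if t > 0 else initial_state[b]
              for b, t in enumerate(sequence_length.tolist())]
    return torch.stack(finals, dim=0)  # [batch, hidden]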
def dynamic_decode(self, helper: Helper, inputs: Optional[torch.Tensor],
                   sequence_length: Optional[torch.LongTensor],
                   initial_state: Optional[State],
                   max_decoding_length: Optional[int] = None,
                   impute_finished: bool = False,
                   step_hook: Optional[Callable[[int], None]] = None) \
        -> Tuple[Output, Optional[State], torch.LongTensor]:
    r"""Generic routine for dynamic decoding. Please check the
    `documentation
    <https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_decode>`_
    for the TensorFlow counterpart.

    Returns:
        A tuple of output, final state, and sequence lengths. Note that
        final state could be `None`, when all sequences are of zero length
        and :attr:`initial_state` is also `None`.
    """
    # Decode
    finished, step_inputs, state = self.initialize(
        helper, inputs, sequence_length, initial_state)

    zero_outputs = step_inputs.new_zeros(
        step_inputs.size(0), self.output_size)

    if max_decoding_length is not None:
        finished |= (max_decoding_length <= 0)
    sequence_lengths = torch.zeros_like(
        finished, dtype=torch.long, device=finished.device)
    time = 0

    outputs = []

    while (not torch.all(finished).item() and
           (max_decoding_length is None or time < max_decoding_length)):
        (next_outputs, decoder_state, next_inputs,
         decoder_finished) = self.step(helper, time, step_inputs, state)

        if getattr(self, 'tracks_own_finished', False):
            next_finished = decoder_finished
        else:
            next_finished = decoder_finished | finished

        # Zero out output values past finish
        if impute_finished:
            emit = utils.map_structure_zip(
                lambda new, cur: torch.where(finished, cur, new),
                (next_outputs, zero_outputs))
            next_state = utils.map_structure_zip(
                lambda new, cur: torch.where(finished, cur, new),
                (decoder_state, state))
        else:
            emit = next_outputs
            next_state = decoder_state

        outputs.append(emit)
        sequence_lengths.index_fill_(
            dim=0, value=time + 1,
            index=torch.nonzero((~finished).long()).flatten())
        time += 1
        finished = next_finished
        step_inputs = next_inputs
        state = next_state

        if step_hook is not None:
            step_hook(time)

    final_outputs = utils.map_structure_zip(
        lambda *tensors: torch.stack(tensors),
        outputs)  # output at each time step may be a namedtuple
    final_state = state
    final_sequence_lengths = sequence_lengths

    try:
        final_outputs, final_state = self.finalize(
            final_outputs, final_state, final_sequence_lengths)
    except NotImplementedError:
        pass

    if not self._output_time_major:
        final_outputs = utils.map_structure(
            lambda x: x.transpose(0, 1) if x.dim() >= 2 else x,
            final_outputs)

    return final_outputs, final_state, final_sequence_lengths
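# Standalone sketch of the `impute_finished` branch above (illustration only):
# once an example is marked finished, `torch.where` keeps emitting the frozen
# value (here zeros) instead of the decoder's new output. The mask is
# unsqueezed so the [batch] flags broadcast over the feature dimension.
def _example_impute_finished():
    finished = torch.tensor([True, False])      # batch of two examples
    new_out = torch.tensor([[1., 1.], [2., 2.]])
    zero_out = torch.zeros_like(new_out)
    emit = torch.where(finished.unsqueeze(1), zero_out, new_out)
    assert emit.tolist() == [[0., 0.], [2., 2.]]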
def test_get_rnn_cell(self):
    r"""Tests :func:`get_rnn_cell`.
    """
    input_size = 10
    hparams = {
        'type': 'LSTMCell',
        'kwargs': {
            'num_units': 20,
            'forget_bias': 1.0,
        },
        'num_layers': 3,
        'dropout': {
            'input_keep_prob': 0.5,
            'output_keep_prob': 0.5,
            'state_keep_prob': 0.5,
            'variational_recurrent': True,
        },
        'residual': True,
        'highway': True,
    }

    hparams = HParams(hparams, default_rnn_cell_hparams())

    rnn_cell = get_rnn_cell(input_size, hparams)

    self.assertIsInstance(rnn_cell, wrappers.MultiRNNCell)
    self.assertEqual(len(rnn_cell._cell), hparams.num_layers)
    self.assertEqual(rnn_cell.input_size, input_size)
    self.assertEqual(rnn_cell.hidden_size, hparams.kwargs.num_units)

    for idx, cell in enumerate(rnn_cell._cell):
        layer_input_size = (input_size if idx == 0
                            else hparams.kwargs.num_units)
        self.assertEqual(cell.input_size, layer_input_size)
        self.assertEqual(cell.hidden_size, hparams.kwargs.num_units)

        if idx > 0:
            highway = cell
            residual = highway._cell
            dropout = residual._cell
            self.assertIsInstance(highway, wrappers.HighwayWrapper)
            self.assertIsInstance(residual, wrappers.ResidualWrapper)
        else:
            dropout = cell
        lstm = dropout._cell
        builtin_lstm = lstm._cell
        self.assertIsInstance(dropout, wrappers.DropoutWrapper)
        self.assertIsInstance(lstm, wrappers.LSTMCell)
        self.assertIsInstance(builtin_lstm, nn.LSTMCell)

        h = hparams.kwargs.num_units
        forget_bias = builtin_lstm.bias_ih[h:(2 * h)]
        self.assertTrue((forget_bias == hparams.kwargs.forget_bias).all())

        for key in ['input', 'output', 'state']:
            self.assertEqual(getattr(dropout, f'_{key}_keep_prob'),
                             hparams.dropout[f'{key}_keep_prob'])
        self.assertTrue(dropout._variational_recurrent)

    batch_size = 8
    seq_len = 6
    state = None
    for step in range(seq_len):
        input = torch.zeros(batch_size, input_size)
        output, state = rnn_cell(input, state)
        self.assertEqual(output.shape,
                         (batch_size, hparams.kwargs.num_units))
        self.assertEqual(len(state), hparams.num_layers)
        utils.map_structure(
            lambda s: self.assertEqual(
                s.shape, (batch_size, hparams.kwargs.num_units)),
            state)