def test_sequence_mask(self):
    r"""Checks :meth:`texar.torch.utils.sequence_mask` with both list and
    tensor length inputs.
    """
    # Case 1: lengths given as a Python list, explicit max length of 5.
    list_expected = np.asarray([[True, False, False, False, False],
                                [True, True, True, False, False],
                                [True, True, False, False, False]])
    np.testing.assert_array_equal(
        utils.sequence_mask([1, 3, 2], 5).numpy(), list_expected)

    # Case 2: lengths given as a 2-D tensor with no explicit max length;
    # the expected output has width 3, the largest length in the input.
    tensor_lengths = torch.tensor([[1, 3], [2, 0]])
    tensor_expected = np.asarray([[[True, False, False],
                                   [True, True, True]],
                                  [[True, True, False],
                                   [False, False, False]]])
    np.testing.assert_array_equal(
        utils.sequence_mask(tensor_lengths), tensor_expected)
def mask_sequences(sequence: Union[torch.Tensor, List[int]],
                   sequence_length: Union[torch.LongTensor, List[int]],
                   dtype: Optional[torch.dtype] = None,
                   time_major: bool = False) -> torch.Tensor:
    r"""Masks out sequence entries that are beyond the respective sequence
    lengths. Masks along the time dimension.

    :attr:`sequence` and :attr:`sequence_length` can either be python
    arrays or Tensors. A python array :attr:`sequence` is converted to a
    Tensor first, and the return value is always a Tensor.

    Args:
        sequence: A Tensor or Python array of sequence values.
            If ``time_major==False`` (default), this must be a Tensor of
            shape ``[batch_size, max_time, ...]``. The batch and time
            dimension is exchanged if ``time_major==True``.
        sequence_length: A Tensor or python array of shape
            ``[batch_size]``. Time steps beyond the respective sequence
            lengths will be made zero.
        dtype (dtype): Type of :attr:`sequence`. If `None`, infer from
            :attr:`sequence` automatically. Also used as the dtype of the
            mask that multiplies :attr:`sequence`.
        time_major (bool): The shape format of the inputs. If `True`,
            :attr:`sequence` must have shape
            ``[max_time, batch_size, ...]``.
            If `False` (default), :attr:`sequence` must have
            shape ``[batch_size, max_time, ...]``.

    Returns:
        The masked sequence, i.e., a Tensor of the same shape as
        :attr:`sequence` but with masked-out entries (set to zero).

    Raises:
        ValueError: if :attr:`sequence` has fewer than 2 dimensions.
    """
    if not torch.is_tensor(sequence):
        sequence = torch.tensor(sequence, dtype=dtype)
    sequence: torch.Tensor

    rank = sequence.dim()
    if rank < 2:
        raise ValueError("`sequence` must be 2D or higher order.")

    # Work in batch-major layout; converted back before returning.
    if time_major:
        sequence = transpose_batch_time(sequence)
    max_time = sequence.size(1)

    if dtype is None:
        dtype = sequence.dtype
    mask = utils.sequence_mask(sequence_length, max_time, dtype=dtype)
    # `sequence_length` may be a Python list or a tensor on another device
    # than `sequence`; align the mask so the multiplication below does not
    # fail with a device mismatch. No-op when the devices already match.
    mask = mask.to(device=sequence.device)
    # Append singleton dims so the [batch_size, max_time] mask broadcasts
    # over any trailing feature dimensions of `sequence`.
    mask = mask.view(*mask.size(), *([1] * (rank - 2)))
    sequence = sequence * mask

    if time_major:
        sequence = transpose_batch_time(sequence)

    return sequence
def maybe_mask_score(score: torch.Tensor,
                     score_mask_value: torch.Tensor,
                     memory_sequence_length: Optional[torch.LongTensor]) \
        -> torch.Tensor:
    r"""Mask the attention score based on the masks.

    Args:
        score: Attention scores. Dimension 1 is used as the mask length
            (presumably ``[batch_size, max_time]`` -- confirm with callers).
        score_mask_value: Value placed at masked-out positions; it is
            broadcast against a ones tensor shaped like :attr:`score`.
        memory_sequence_length: Per-example memory lengths, or `None` to
            skip masking entirely.

    Returns:
        :attr:`score` itself when :attr:`memory_sequence_length` is `None`.
        Otherwise a tensor shaped like :attr:`score` where positions
        outside each sequence length hold :attr:`score_mask_value`.

    Raises:
        ValueError: if any value in :attr:`memory_sequence_length` is not
            strictly positive.
    """
    if memory_sequence_length is None:
        return score

    # One vectorized reduction instead of a Python loop over elements; on
    # GPU the loop forced a host synchronization for every element.
    if bool((torch.as_tensor(memory_sequence_length) <= 0).any()):
        raise ValueError(
            "All values in memory_sequence_length must be greater "
            "than zero.")

    score_mask = sequence_mask(memory_sequence_length,
                               max_len=score.shape[1])
    score_mask_values = score_mask_value * torch.ones_like(score)
    return torch.where(score_mask, score, score_mask_values)
def prepare_memory(memory: torch.Tensor,
                   memory_sequence_length: Optional[torch.LongTensor]) \
        -> torch.Tensor:
    r"""Convert to tensor and possibly mask ``memory``.

    Args:
        memory: tensor, shaped ``[batch_size, max_time, ...]``.
        memory_sequence_length: integer tensor (or python list), shaped
            ``[batch_size]``. If `None`, ``memory`` is returned unchanged.

    Returns:
        A (possibly masked), new ``memory``. Positions beyond each
        sequence length are zeroed by multiplying with a length mask.

    Raises:
        ValueError: if ``memory`` and ``memory_sequence_length`` do not
            have the same ``batch_size``.
    """
    lengths = memory_sequence_length
    # Lists are materialized on the same device as `memory`.
    if lengths is not None and not isinstance(lengths, torch.Tensor):
        lengths = torch.tensor(lengths, dtype=torch.long,
                               device=memory.device)

    length_mask = None
    if lengths is not None:
        length_mask = sequence_mask(lengths, max_len=memory.shape[1],
                                    dtype=memory.dtype)
        lengths_batch = lengths.shape[0]

    num_dims = memory.dim()
    memory_batch = memory.shape[0]

    if length_mask is None:
        return memory

    if lengths_batch != memory_batch:
        raise ValueError("memory_sequence_length and memory tensor "
                         "batch sizes do not match.")

    # Pad the mask shape with trailing singleton dims so it broadcasts
    # over memory's feature dimensions.
    broadcast_shape = length_mask.size() + (1, ) * (num_dims - 2)
    return memory * length_mask.view(broadcast_shape)
def _discount_reward_tensor_1d(reward: torch.Tensor,
                               sequence_length: Optional[torch.LongTensor],
                               discount: float = 1.) -> torch.Tensor:
    r"""Computes discounted reward for a per-sequence (1D) reward.

    The scalar reward of each sequence is spread over its time steps. With
    ``discount == 1`` every valid step receives the full reward. Otherwise
    step ``t`` of a sequence of length ``L`` receives
    ``reward * discount ** (L - 1 - t)``.

    Args:
        reward: 1D Tensor with shape `[batch_size]`.
        sequence_length: A Tensor (or python list) of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths will be
            masked.
        discount (float): A scalar. The discount factor.

    Returns:
        A 2D Tensor of the discounted reward.

    Raises:
        ValueError: if :attr:`sequence_length` is `None`.
    """
    if sequence_length is None:
        raise ValueError('sequence_length must not be `None` for 1D reward.')
    if not isinstance(sequence_length, torch.Tensor):
        sequence_length = torch.tensor(
            sequence_length, dtype=torch.int64, device=reward.device)

    n_batch = reward.shape[0]
    longest = torch.max(sequence_length)
    reward_dtype: torch.dtype = reward.dtype
    column_reward = reward.unsqueeze(-1)

    if discount != 1.:
        valid = sequence_mask(sequence_length, dtype=reward_dtype)
        # Shift left by one step so the last valid step of each row is 0:
        # it is the only valid step that gets no discount factor.
        shifted = torch.cat(
            (valid[:, 1:], torch.zeros_like(valid[:, -1:])), dim=1)
        # Each row becomes [discount, ..., discount, 1, ..., 1].
        step_factor = shifted * discount + (1 - shifted)
        # Reverse cumulative product along time gives discount ** (L-1-t).
        step_factor = step_factor.flip(1).cumprod(dim=1).flip(1)
        discounted = step_factor * column_reward
    else:
        discounted = column_reward.expand(n_batch, longest)

    return mask_sequences(discounted, sequence_length, dtype=reward_dtype)
def forward(self,  # type: ignore
            inputs: Optional[torch.Tensor] = None,
            sequence_length: Optional[torch.LongTensor] = None,
            memory: Optional[torch.Tensor] = None,
            memory_sequence_length: Optional[torch.LongTensor] = None,
            memory_attention_bias: Optional[torch.Tensor] = None,
            context: Optional[torch.Tensor] = None,
            context_sequence_length: Optional[torch.LongTensor] = None,
            helper: Optional[Helper] = None,
            decoding_strategy: str = 'train_greedy',
            max_decoding_length: Optional[int] = None,
            impute_finished: bool = False,
            infer_mode: Optional[bool] = None,
            beam_width: Optional[int] = None,
            length_penalty: float = 0.,
            **kwargs) \
        -> Union[
            TransformerDecoderOutput,
            Tuple[TransformerDecoderOutput, torch.LongTensor],
            Dict[str, torch.Tensor]]:
    r"""Performs decoding.

    The interface is very similar to that of RNN decoders
    (:class:`~texar.torch.modules.RNNDecoderBase`). In particular,
    the function provides **3 ways** to specify the decoding method, with
    varying flexibility:

    1. The :attr:`decoding_strategy` argument.

       - **"train_greedy"**: decoding in teacher-forcing fashion (i.e.,
         feeding ground truth to decode the next step), and for each step
         sample is obtained by taking the `argmax` of logits.
         Argument :attr:`inputs` is required for this strategy.
         :attr:`sequence_length` is optional.
       - **"infer_greedy"**: decoding in inference fashion (i.e., feeding
         `generated` sample to decode the next step), and for each step
         sample is obtained by taking the `argmax` of logits.
         Arguments :attr:`(start_tokens, end_token)` are
         required for this strategy, and argument
         :attr:`max_decoding_length` is optional.
       - **"infer_sample"**: decoding in inference fashion, and for each
         step sample is obtained by `random sampling` from the logits.
         Arguments :attr:`(start_tokens, end_token)` are required for this
         strategy, and argument :attr:`max_decoding_length` is optional.

       This argument is used only when arguments :attr:`helper` and
       :attr:`beam_width` are both `None`.

    2. The :attr:`helper` argument: An instance of subclass of
       :class:`~texar.torch.modules.Helper`.
       This provides a superset of the decoding strategies above.
       The interface is the same as in RNN decoders.
       Please refer to :meth:`texar.torch.modules.RNNDecoderBase.forward`
       for detailed usage and examples.

       Note that, here, though using a
       :class:`~texar.torch.modules.TrainingHelper` corresponding to the
       ``"train_greedy"`` strategy above, the implementation is *slower*
       than directly setting ``decoding_strategy="train_greedy"`` (though
       output results are the same).

       Argument :attr:`max_decoding_length` is optional.

    3. **Beam search**: set :attr:`beam_width` to use beam search decoding.
       Arguments :attr:`(start_tokens, end_token)` are required,
       and argument :attr:`max_decoding_length` is optional.

    Args:
        memory (optional): The memory to attend, e.g., the output of an RNN
            encoder. A :tensor:`Tensor` of shape
            ``[batch_size, memory_max_time, dim]``.
        memory_sequence_length (optional): A :tensor:`Tensor` of shape
            ``[batch_size]`` containing the sequence lengths for the batch
            entries in memory. Used to create the attention bias if
            :attr:`memory_attention_bias` is not given. Ignored if
            :attr:`memory_attention_bias` is provided.
        memory_attention_bias (optional): A :tensor:`Tensor` of shape
            ``[batch_size, num_heads, memory_max_time, dim]``.
            An attention bias typically sets the value of a padding
            position to a large negative value for masking. If not given,
            :attr:`memory_sequence_length` is used to automatically
            create an attention bias.
        inputs (optional): Input tensors for teacher forcing decoding.
            Used when :attr:`decoding_strategy` is set to
            ``"train_greedy"``, or when `hparams`-configured helper is used.

            The :attr:`inputs` is a :tensor:`LongTensor` used as index to
            look up embeddings and feed in the decoder. For example, if
            :attr:`embedder` is an instance of
            :class:`~texar.torch.modules.WordEmbedder`, then :attr:`inputs`
            is usually a 2D int Tensor `[batch_size, max_time]` (or
            `[max_time, batch_size]` if `input_time_major` == `True`)
            containing the token indexes.
        sequence_length (optional): A :tensor:`LongTensor` of shape
            ``[batch_size]``, containing the sequence length of
            :attr:`inputs`. Tokens beyond the respective sequence length are
            masked out.
            Used when :attr:`decoding_strategy` is set to
            ``"train_greedy"``.
        decoding_strategy (str): A string specifying the decoding
            strategy, including ``"train_greedy"``, ``"infer_greedy"``,
            ``"infer_sample"``.
            Different arguments are required based on the
            strategy. See above for details. Ignored if
            :attr:`beam_width` or :attr:`helper` is set.
        beam_width (int): Set to use beam search. If given,
            :attr:`decoding_strategy` is ignored.
        length_penalty (float): Length penalty coefficient used in beam
            search decoding. Refer to https://arxiv.org/abs/1609.08144
            for more details.
            It should be larger if longer sentences are desired.
        context (optional): A :tensor:`LongTensor` of shape
            ``[batch_size, length]``, containing the starting tokens for
            decoding. If context is set, ``start_tokens`` of the
            :class:`~texar.torch.modules.Helper` will be ignored.
        context_sequence_length (optional): Specify the length of context.
        max_decoding_length (int, optional): The maximum allowed number of
            decoding steps.
            If `None` (default), use ``"max_decoding_length"`` defined in
            :attr:`hparams`. Ignored in ``"train_greedy"`` decoding.
        impute_finished (bool): If `True`, then states for batch
            entries which are marked as finished get copied through and
            the corresponding outputs get zeroed out.  This causes some
            slowdown at each time step, but ensures that the final state
            and outputs have the correct values and that backprop ignores
            time steps that were marked as finished. Ignored in
            ``"train_greedy"`` decoding.
        helper (optional): An instance of
            :class:`~texar.torch.modules.Helper`
            that defines the decoding strategy. If given,
            ``decoding_strategy`` and helper configurations in
            :attr:`hparams` are ignored.
        infer_mode (optional): If not `None`, overrides mode given by
            :attr:`self.training`.

    Returns:

        - For **"train_greedy"** decoding, returns an instance of
          :class:`~texar.torch.modules.TransformerDecoderOutput` which
          contains `sample_id` and `logits`.

        - For **"infer_greedy"** and **"infer_sample"** decoding or
          decoding with :attr:`helper`, returns
          a tuple ``(outputs, sequence_lengths)``, where ``outputs`` is an
          instance of :class:`~texar.torch.modules.TransformerDecoderOutput`
          as in `"train_greedy"`, and ``sequence_lengths`` is a
          :tensor:`LongTensor` of shape ``[batch_size]`` containing the
          length of each sample.

        - For **beam search** decoding, returns a ``dict`` containing keys
          ``"sample_id"`` and ``"log_prob"``.

            - ``"sample_id"`` is a :tensor:`LongTensor` of shape
              ``[batch_size, max_time, beam_width]`` containing generated
              token indexes. ``sample_id[:,:,0]`` is the most probable
              sample.
            - ``"log_prob"`` is a :tensor:`Tensor` of shape
              ``[batch_size, beam_width]`` containing the log probability
              of each sequence sample.

    Raises:
        ValueError: if :attr:`memory` is given without either
            :attr:`memory_attention_bias` or :attr:`memory_sequence_length`;
            if :attr:`context` is given without
            :attr:`context_sequence_length`; if ``"train_greedy"`` is used
            without :attr:`inputs`; or if beam search is requested together
            with :attr:`helper` or without ``start_tokens``/:attr:`context`.
    """

    if memory is not None:
        if memory_attention_bias is None:
            if memory_sequence_length is None:
                raise ValueError(
                    "`memory_sequence_length` is required if "
                    "`memory_attention_bias` is not given.")

            # 1.0 at padded memory positions, 0.0 at valid ones.
            enc_padding = 1 - sequence_mask(
                memory_sequence_length, memory.size(1),
                dtype=torch.float32)
            memory_attention_bias = attn.attention_bias_ignore_padding(
                enc_padding)

    # record the context, which will be used in step function
    # for dynamic_decode
    if context is not None:
        if context_sequence_length is None:
            raise ValueError("'context_sequence_length' must not be None "
                             "when 'context' is specified.")
        # The first context token serves as the start token; the rest are
        # stored for the step function.
        self._state_context = context[:, 1:]
        self._state_context_sequence_length = context_sequence_length - 1
    else:
        self._state_context = None
        self._state_context_sequence_length = None

    # Faster code path for teacher-forcing training
    if (helper is None and beam_width is None and
            decoding_strategy == 'train_greedy'):
        if inputs is None:
            raise ValueError(
                "'input' must not be none "
                "when using 'train_greedy' decoding strategy.")
        # Position index for every token: [batch_size, max_time].
        times = torch.arange(
            inputs.size(1), dtype=torch.long, device=inputs.device)
        times = times.unsqueeze(0).expand(inputs.size(0), -1)
        inputs = self.embed_tokens(inputs, times)
        if sequence_length is not None:
            inputs = mask_sequences(inputs, sequence_length)

        decoder_self_attention_bias = (
            attn.attention_bias_lower_triangle(inputs.size(1)))

        decoder_output = self._self_attention_stack(
            inputs, memory, decoder_self_attention_bias,
            memory_attention_bias, cache=None)
        logits = self._output_layer(decoder_output)
        sample_id = torch.argmax(logits, dim=-1)

        return TransformerDecoderOutput(logits, sample_id)

    # Inference code path.
    if max_decoding_length is None:
        max_decoding_length = self._hparams.max_decoding_length

    self._state_max_decoding_length = max_decoding_length

    if beam_width is None or beam_width == 1:  # Inference-like decoding
        # Prepare helper
        if helper is None:
            kwargs.update(decoding_strategy=decoding_strategy)
            if context is not None:
                kwargs.update(start_tokens=context[:, 0])
            helper = self._create_or_get_helper(infer_mode, **kwargs)
        assert isinstance(helper, EmbeddingHelper)

        self._state_cache = self._init_cache(
            memory, memory_attention_bias,
            beam_search_decoding=False, batch_size=helper.batch_size)
        if context is not None:
            assert self._state_context is not None
            # Right-pad the stored context with zeros up to the maximum
            # decoding length.
            pad_length = max_decoding_length - self._state_context.size(1)
            if pad_length > 0:
                self._state_context = torch.cat((
                    self._state_context,
                    self._state_context.new_zeros(
                        self._state_context.size(0), pad_length)
                ), dim=1)

        outputs, cache, sequence_lengths = self.dynamic_decode(
            helper, inputs=None, sequence_length=None,
            initial_state=None, max_decoding_length=max_decoding_length,
            impute_finished=impute_finished)
        del cache  # not used

        if context is not None:
            # Here the length of sample_id will be larger than that
            # of logit by 1, because there will be an additional
            # start_token in the returned sample_id.
            # the start_id should be the first token of the
            # given context
            start_tokens = context[:, 0]
            outputs = TransformerDecoderOutput(
                logits=outputs.logits,
                sample_id=torch.cat([
                    start_tokens.unsqueeze(1),
                    outputs.sample_id
                ], dim=1))
            sequence_lengths = sequence_lengths + 1

        return outputs, sequence_lengths

    else:  # Beam-search decoding
        # Ignore `decoding_strategy` and
        # assume `helper` is not set.
        if helper is not None:
            raise ValueError("Must not set 'beam_width' and 'helper' "
                             "simultaneously.")
        if context is not None:
            start_tokens = context[:, 0]
        else:
            if 'start_tokens' not in kwargs:
                raise ValueError(
                    "'start_tokens' must be specified when using "
                    "beam search decoding.")
            start_tokens = kwargs['start_tokens']
        _batch_size = start_tokens.size(0)
        self._state_cache = self._init_cache(
            memory, memory_attention_bias,
            beam_search_decoding=True,
            batch_size=_batch_size)
        end_token: int = kwargs.get('end_token')  # type: ignore

        # The output format is different when running beam search.
        sample_id, log_prob = self.beam_decode(
            start_tokens,
            end_token,
            embedding_fn=self.embed_tokens,
            beam_width=beam_width,
            length_penalty=length_penalty,
            decode_length=max_decoding_length)

        return {
            'sample_id': sample_id,
            'log_prob': log_prob
        }
def forward(self,  # type: ignore
            inputs: torch.Tensor,
            sequence_length: torch.LongTensor) -> torch.Tensor:
    r"""Encodes the inputs.

    Args:
        inputs: A 3D Tensor of shape ``[batch_size, max_time, dim]``,
            containing the embedding of input sequences. Note that
            the embedding dimension `dim` must equal "dim" in
            :attr:`hparams`. The input embedding is typically an
            aggregation of word embedding and position embedding.
        sequence_length: A 1D :tensor:`LongTensor` of shape
            ``[batch_size]``. Input tokens beyond respective sequence
            lengths are masked out automatically.

    Returns:
        A Tensor of shape ``[batch_size, max_time, dim]`` containing the
        encoded vectors.
    """
    use_bert = self._hparams.use_bert_config

    # 1.0 at padded positions, 0.0 at valid ones; turned into an additive
    # self-attention bias. The BERT configuration uses a milder bias value.
    padding = 1 - sequence_mask(sequence_length, inputs.size(1)).float()
    if use_bert:
        self_attn_bias = attn.attention_bias_ignore_padding(
            padding, bias_value=-1e4)
    else:
        self_attn_bias = attn.attention_bias_ignore_padding(padding)

    # BERT normalizes the embeddings before dropout; the original
    # Transformer applies dropout directly.
    hidden = self.input_normalizer(inputs) if use_bert else inputs
    hidden = self.embed_dropout(hidden)

    for block in range(self._hparams.num_blocks):
        # Self-attention sub-layer. BERT is post-norm; the original
        # Transformer normalizes the sub-layer input (pre-norm).
        attn_input = (hidden if use_bert
                      else self.self_attn_layer_norm[block](hidden))
        attn_out = self.self_attns[block](
            queries=attn_input,
            memory=attn_input,
            memory_attention_bias=self_attn_bias,
        )
        hidden = hidden + self.residual_dropout(attn_out)

        # Position-wise feed-forward sub-layer, applied on a flattened
        # [-1, dim] view and reshaped back afterwards.
        ffn_norm = self.poswise_layer_norm[block]
        if use_bert:
            hidden = ffn_norm(hidden)
            ffn_input = hidden
        else:
            ffn_input = ffn_norm(hidden)
        ffn_shape = ffn_input.size()
        ffn_out = self.poswise_networks[block](
            ffn_input.view(-1, self._hparams.dim))
        ffn_out = self.residual_dropout(ffn_out).view(ffn_shape)
        hidden = hidden + ffn_out

        if use_bert:
            hidden = self.output_layer_norm[block](hidden)

    if not use_bert:
        hidden = self.final_layer_norm(hidden)
    return hidden