def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        msg: str = 'Expected hidden size {}, got {}') -> None:
    """Verify that ``hx`` has exactly the expected size.

    Args:
        hx: tensor whose size is checked.
        expected_hidden_size: the size that ``hx.size()`` is compared with.
        msg: format string used for the error; it receives the expected
            size and the actual size (converted to a tuple), in that order.

    Raises:
        RuntimeError: if ``hx.size()`` differs from ``expected_hidden_size``.
    """
    actual_size = hx.size()
    if actual_size != expected_hidden_size:
        details = msg.format(expected_hidden_size, tuple(actual_size))
        raise RuntimeError(details)
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
    """Validate the rank and the last-dimension size of ``input``.

    Args:
        input: tensor to validate.
        batch_sizes: when not ``None`` the input must be 2-dimensional;
            when ``None`` it must be 3-dimensional.

    Raises:
        RuntimeError: if ``input`` has the wrong number of dimensions, or
            if its last dimension is not equal to ``self.input_size``.
    """
    if batch_sizes is None:
        expected_input_dim = 3
    else:
        expected_input_dim = 2

    actual_dim = input.dim()
    if actual_dim != expected_input_dim:
        raise RuntimeError(
            'input must have {} dimensions, got {}'.format(
                expected_input_dim, actual_dim))

    # Only inspect the last dimension once the rank is known to be valid.
    feature_size = input.size(-1)
    if self.input_size != feature_size:
        raise RuntimeError(
            'input.size(-1) must be equal to input_size. Expected {}, got {}'
            .format(self.input_size, feature_size))
def forward(self, input: Tensor, target: Tensor) -> _ASMoutput:
    """Compute the per-row target log-probability and the mean loss.

    Targets are bucketed by the ranges formed from ``[0] + self.cutoffs``.
    Rows whose target falls in the first range look their target up
    directly in the head output. Rows in a later range ``i`` look up
    index ``self.shortlist_size + i - 1`` in the head output and add the
    log-probability of the range-relative target from ``self.tail[i - 1]``.

    Args:
        input: tensor whose first dimension is the batch dimension.
        target: tensor of target indices, one per row of ``input``.

    Returns:
        ``_ASMoutput(output, loss)`` where ``output`` holds one
        log-probability per row and ``loss`` is the mean of ``-output``.

    Raises:
        RuntimeError: if ``input`` and ``target`` disagree in the batch
            dimension, or if some target falls outside every range.
    """
    if input.size(0) != target.size(0):
        raise RuntimeError('Input and target should have the same size '
                           'in the batch dimension.')

    batch_size = target.size(0)
    output = input.new_zeros(batch_size)
    gather_inds = target.new_empty(batch_size)
    rows_seen = 0

    boundaries = [0] + self.cutoffs
    for bucket, (lower, upper) in enumerate(zip(boundaries[:-1],
                                                boundaries[1:])):
        in_bucket = (target >= lower) & (target < upper)
        rows = in_bucket.nonzero().squeeze()
        row_count = rows.numel()
        if row_count == 0:
            continue

        bucket_targets = target[in_bucket]
        if bucket == 0:
            # Shortlist targets index the head output directly.
            gather_inds.index_copy_(0, rows, bucket_targets)
        else:
            offsets = bucket_targets - lower
            tail_logits = self.tail[bucket - 1](input.index_select(0, rows))
            # In the head output these rows select their cluster's slot.
            gather_inds.index_fill_(
                0, rows, self.shortlist_size + bucket - 1)
            tail_logprob = log_softmax(tail_logits, dim=1)
            picked = tail_logprob.gather(1, offsets.unsqueeze(1))
            output.index_copy_(0, rows, picked.squeeze(1))

        rows_seen += row_count

    # Any row that matched no range carries an out-of-range target.
    if rows_seen != batch_size:
        raise RuntimeError("Target values should be in [0, {}], "
                           "but values in range [{}, {}] "
                           "were found. ".format(self.n_classes - 1,
                                                 target.min().item(),
                                                 target.max().item()))

    head_logprob = log_softmax(self.head(input), dim=1)
    output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
    loss = (-output).mean()

    return _ASMoutput(output, loss)
def get_expected_hidden_size(
        self,
        input: Tensor,
        batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
    """Return the hidden-state size expected for ``input``.

    Args:
        input: input tensor; consulted for the batch size only when
            ``batch_sizes`` is ``None``.
        batch_sizes: when given, its first element is used as the batch
            size.

    Returns:
        ``(num_layers * num_directions, mini_batch, hidden_size)`` where
        ``num_directions`` is 2 if ``self.bidirectional`` else 1.
    """
    if batch_sizes is None:
        # The batch dimension is 0 for batch-first layouts, 1 otherwise.
        batch_dim = 0 if self.batch_first else 1
        mini_batch = input.size(batch_dim)
    else:
        mini_batch = int(batch_sizes[0])

    num_directions = 2 if self.bidirectional else 1
    return (self.num_layers * num_directions, mini_batch, self.hidden_size)
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
    """Run one cell step with the configured nonlinearity.

    Args:
        input: input tensor; its first dimension sizes the default state.
        hx: previous hidden state. When ``None``, a zero tensor of size
            ``(input.size(0), self.hidden_size)`` with the dtype and
            device of ``input`` is used.

    Returns:
        The result of ``_VF.rnn_tanh_cell`` or ``_VF.rnn_relu_cell``,
        depending on ``self.nonlinearity``.

    Raises:
        RuntimeError: if ``self.nonlinearity`` is neither ``"tanh"`` nor
            ``"relu"``.
    """
    self.check_forward_input(input)
    if hx is None:
        hx = torch.zeros(
            input.size(0), self.hidden_size,
            dtype=input.dtype, device=input.device)
    self.check_forward_hidden(input, hx, '')

    if self.nonlinearity == "relu":
        ret = _VF.rnn_relu_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
    elif self.nonlinearity == "tanh":
        ret = _VF.rnn_tanh_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
    else:
        # `ret` is bound on this path too.
        # TODO: remove when jit supports exception flow
        ret = input
        raise RuntimeError(
            "Unknown nonlinearity: {}".format(self.nonlinearity))
    return ret
def forward(self, input: Tensor,
            hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Run the recurrent implementation selected by ``self.mode``.

    Args:
        input: either a tensor or a ``PackedSequence``. A packed input is
            unpacked into its data, batch sizes and sort indices.
        hx: initial hidden state. When ``None``, zeros of size
            ``(num_layers * num_directions, max_batch_size, hidden_size)``
            are created with the dtype and device of the input data.

    Returns:
        ``(output, hidden)``. ``output`` is re-wrapped in a
        ``PackedSequence`` when the input was packed, and ``hidden`` is
        passed through ``self.permute_hidden`` with the unsorted indices.
    """
    packed = isinstance(input, PackedSequence)
    if packed:
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])
    else:
        batch_sizes = None
        sorted_indices = None
        unsorted_indices = None
        batch_dim = 0 if self.batch_first else 1
        max_batch_size = input.size(batch_dim)

    if hx is None:
        direction_count = 2 if self.bidirectional else 1
        hx = torch.zeros(
            self.num_layers * direction_count,
            max_batch_size,
            self.hidden_size,
            dtype=input.dtype,
            device=input.device)
    else:
        # A caller-supplied state is laid out to match the sequences as
        # the caller passed them in, so reorder it with sorted_indices.
        hx = self.permute_hidden(hx, sorted_indices)

    self.check_forward_args(input, hx, batch_sizes)

    rnn_fn = _rnn_impls[self.mode]
    if batch_sizes is not None:
        result = rnn_fn(
            input, batch_sizes, hx, self._flat_weights, self.bias,
            self.num_layers, self.dropout, self.training,
            self.bidirectional)
    else:
        result = rnn_fn(
            input, hx, self._flat_weights, self.bias,
            self.num_layers, self.dropout, self.training,
            self.bidirectional, self.batch_first)

    output, hidden = result[0], result[1]
    if packed:
        output = PackedSequence(
            output, batch_sizes, sorted_indices, unsorted_indices)
    return output, self.permute_hidden(hidden, unsorted_indices)
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
    """Run one cell step through ``_VF.gru_cell``.

    Args:
        input: input tensor; its first dimension sizes the default state.
        hx: previous hidden state. When ``None``, a zero tensor of size
            ``(input.size(0), self.hidden_size)`` with the dtype and
            device of ``input`` is used.

    Returns:
        The result of ``_VF.gru_cell`` for the given input and state.
    """
    self.check_forward_input(input)

    if hx is None:
        batch = input.size(0)
        hx = torch.zeros(
            batch, self.hidden_size,
            dtype=input.dtype, device=input.device)

    self.check_forward_hidden(input, hx, '')

    next_hidden = _VF.gru_cell(
        input, hx,
        self.weight_ih, self.weight_hh,
        self.bias_ih, self.bias_hh,
    )
    return next_hidden
def check_forward_hidden(self, input: Tensor, hx: Tensor,
                         hidden_label: str = '') -> None:
    """Check that a hidden state is compatible with ``input``.

    Args:
        input: input tensor; only ``input.size(0)`` is consulted.
        hx: hidden state to validate.
        hidden_label: suffix inserted after the word "hidden" in error
            messages (for example ``'[0]'``).

    Raises:
        RuntimeError: if the first dimensions of ``input`` and ``hx``
            differ, or if ``hx.size(1)`` is not ``self.hidden_size``.
    """
    input_batch = input.size(0)
    hidden_batch = hx.size(0)
    if input_batch != hidden_batch:
        raise RuntimeError(
            "Input batch size {} doesn't match hidden{} batch size {}".format(
                input_batch, hidden_label, hidden_batch))

    hidden_width = hx.size(1)
    if hidden_width != self.hidden_size:
        raise RuntimeError(
            "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                hidden_label, hidden_width, self.hidden_size))
def forward(
        self,
        input: Tensor,
        hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
    """Run one cell step through ``_VF.lstm_cell``.

    Args:
        input: input tensor; its first dimension sizes the default state.
        hx: pair of state tensors. When ``None``, both entries are the
            same zero tensor of size ``(input.size(0), self.hidden_size)``
            with the dtype and device of ``input``.

    Returns:
        The result of ``_VF.lstm_cell`` for the given input and state.
    """
    self.check_forward_input(input)

    if hx is None:
        initial_state = torch.zeros(
            input.size(0), self.hidden_size,
            dtype=input.dtype, device=input.device)
        hx = (initial_state, initial_state)

    # Validate each half of the state; the label names the failing entry.
    self.check_forward_hidden(input, hx[0], '[0]')
    self.check_forward_hidden(input, hx[1], '[1]')

    next_state = _VF.lstm_cell(
        input, hx,
        self.weight_ih, self.weight_hh,
        self.bias_ih, self.bias_hh,
    )
    return next_state
def check_forward_input(self, input: Tensor) -> None:
    """Check that ``input.size(1)`` equals ``self.input_size``.

    Args:
        input: tensor to validate.

    Raises:
        RuntimeError: if ``input.size(1)`` differs from ``self.input_size``.
    """
    feature_size = input.size(1)
    if feature_size != self.input_size:
        raise RuntimeError(
            "input has inconsistent input_size: got {}, expected {}".format(
                feature_size, self.input_size))
def forward(self, src: Tensor, tgt: Tensor,
            src_mask: Optional[Tensor] = None,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
    r"""Encode the masked source sequence and decode the target against it.

    Args:
        src: the sequence to the encoder (required).
        tgt: the sequence to the decoder (required).
        src_mask: the additive mask for the src sequence (optional).
        tgt_mask: the additive mask for the tgt sequence (optional).
        memory_mask: the additive mask for the encoder output (optional).
        src_key_padding_mask: the ByteTensor mask for src keys per batch
            (optional).
        tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch
            (optional).
        memory_key_padding_mask: the ByteTensor mask for memory keys per
            batch (optional).

    Shape:
        - src: :math:`(S, N, E)`.
        - tgt: :math:`(T, N, E)`.
        - src_mask: :math:`(S, S)`.
        - tgt_mask: :math:`(T, T)`.
        - memory_mask: :math:`(T, S)`.
        - src_key_padding_mask: :math:`(N, S)`.
        - tgt_key_padding_mask: :math:`(N, T)`.
        - memory_key_padding_mask: :math:`(N, S)`.
        - output: :math:`(T, N, E)`.

        where S is the source sequence length, T is the target sequence
        length, N is the batch size, E is the feature number.

    Note: [src/tgt/memory]_mask ensures that position i is allowed to
    attend the unmasked positions. If a ByteTensor is provided, the
    non-zero positions are not allowed to attend while the zero positions
    will be unchanged. If a BoolTensor is provided, positions with
    ``True`` are not allowed to attend while ``False`` values will be
    unchanged. If a FloatTensor is provided, it will be added to the
    attention weight.

    [src/tgt/memory]_key_padding_mask provides specified elements in the
    key to be ignored by the attention. If a ByteTensor is provided, the
    non-zero positions will be ignored while the zero positions will be
    unchanged. If a BoolTensor is provided, the positions with the value
    of ``True`` will be ignored while the positions with the value of
    ``False`` will be unchanged.

    Note: Due to the multi-head attention architecture in the transformer
    model, the output sequence length of a transformer is the same as the
    input sequence (i.e. target) length of the decoder.

    Raises:
        RuntimeError: if ``src`` and ``tgt`` differ in dimension 1, or if
            dimension 2 of either is not equal to ``self.d_model``.

    Examples:
        >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
    """
    if src.size(1) != tgt.size(1):
        raise RuntimeError("the batch number of src and tgt must be equal")

    features_match = (src.size(2) == self.d_model
                      and tgt.size(2) == self.d_model)
    if not features_match:
        raise RuntimeError(
            "the feature number of src and tgt must be equal to d_model")

    encoded = self.encoder(
        src,
        mask=src_mask,
        src_key_padding_mask=src_key_padding_mask)

    return self.decoder(
        tgt,
        encoded,
        tgt_mask=tgt_mask,
        memory_mask=memory_mask,
        tgt_key_padding_mask=tgt_key_padding_mask,
        memory_key_padding_mask=memory_key_padding_mask)