Example #1
 def check_hidden_size(
         self,
         hx: Tensor,
         expected_hidden_size: Tuple[int, int, int],
         msg: str = 'Expected hidden size {}, got {}') -> None:
     if hx.size() != expected_hidden_size:
         raise RuntimeError(
             msg.format(expected_hidden_size, tuple(hx.size())))
Example #2
 def check_input(self, input: Tensor,
                 batch_sizes: Optional[Tensor]) -> None:
     expected_input_dim = 2 if batch_sizes is not None else 3
     if input.dim() != expected_input_dim:
         raise RuntimeError('input must have {} dimensions, got {}'.format(
             expected_input_dim, input.dim()))
     if self.input_size != input.size(-1):
         raise RuntimeError(
             'input.size(-1) must be equal to input_size. Expected {}, got {}'
             .format(self.input_size, input.size(-1)))
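
These checks are what surface as a RuntimeError when the shapes are wrong. A minimal sketch of triggering the input_size check, assuming an nn.LSTM instance (the sizes are illustrative):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=10, hidden_size=20)
    bad_input = torch.randn(5, 3, 12)   # last dimension is 12, but input_size is 10
    try:
        lstm(bad_input)
    except RuntimeError as err:
        print(err)  # input.size(-1) must be equal to input_size. Expected 10, got 12
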
Example #3
    def forward(self, input: Tensor, target: Tensor) -> _ASMoutput:
        if input.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):

            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                gather_inds.index_fill_(0, row_indices, cluster_index)

                cluster_logprob = log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError("Target values should be in [0, {}], "
                               "but values in range [{}, {}] "
                               "were found. ".format(self.n_classes - 1,
                                                     target.min().item(),
                                                     target.max().item()))

        head_output = self.head(input)
        head_logprob = log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        return _ASMoutput(output, loss)
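
This forward pass belongs to the adaptive softmax module; a hedged usage sketch, assuming it is torch.nn.AdaptiveLogSoftmaxWithLoss (the feature size, class count, and cutoffs are illustrative):

    import torch
    import torch.nn as nn

    # 64-dimensional features, 1000 classes, shortlist of 100 plus two tail clusters
    asm = nn.AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=1000, cutoffs=[100, 500])

    features = torch.randn(32, 64)            # (batch_size, in_features)
    targets = torch.randint(0, 1000, (32,))   # class indices in [0, n_classes)

    result = asm(features, targets)           # named tuple (output, loss), as returned above
    print(result.output.shape, result.loss.item())
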
Example #4
 def get_expected_hidden_size(
         self, input: Tensor,
         batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
     if batch_sizes is not None:
         mini_batch = batch_sizes[0]
         mini_batch = int(mini_batch)
     else:
         mini_batch = input.size(0) if self.batch_first else input.size(1)
     num_directions = 2 if self.bidirectional else 1
     expected_hidden_size = (self.num_layers * num_directions, mini_batch,
                             self.hidden_size)
     return expected_hidden_size
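
The tuple returned here is what check_hidden_size in Example #1 compares against. A short sketch of the expected shape for a concrete configuration, assuming an nn.LSTM (the sizes are illustrative):

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
    x = torch.randn(5, 3, 10)   # (seq_len, batch, input_size), batch_first=False
    # expected hidden size: (num_layers * num_directions, mini_batch, hidden_size) = (4, 3, 20)
    h0 = torch.zeros(4, 3, 20)
    c0 = torch.zeros(4, 3, 20)
    output, (hn, cn) = rnn(x, (h0, c0))   # shapes match, so no RuntimeError is raised
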
Example #5
 def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
     self.check_forward_input(input)
     if hx is None:
         hx = torch.zeros(input.size(0),
                          self.hidden_size,
                          dtype=input.dtype,
                          device=input.device)
     self.check_forward_hidden(input, hx, '')
     if self.nonlinearity == "tanh":
         ret = _VF.rnn_tanh_cell(
             input,
             hx,
             self.weight_ih,
             self.weight_hh,
             self.bias_ih,
             self.bias_hh,
         )
     elif self.nonlinearity == "relu":
         ret = _VF.rnn_relu_cell(
             input,
             hx,
             self.weight_ih,
             self.weight_hh,
             self.bias_ih,
             self.bias_hh,
         )
     else:
         ret = input  # TODO: remove when jit supports exception flow
         raise RuntimeError("Unknown nonlinearity: {}".format(
             self.nonlinearity))
     return ret
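
A usage sketch for this cell-level forward, assuming it belongs to torch.nn.RNNCell (the sizes are illustrative); when hx is omitted it is zero-initialized exactly as above:

    import torch
    import torch.nn as nn

    cell = nn.RNNCell(input_size=10, hidden_size=20, nonlinearity='tanh')
    x = torch.randn(6, 3, 10)   # (seq_len, batch, input_size)
    hx = torch.zeros(3, 20)     # (batch, hidden_size)
    outputs = []
    for t in range(x.size(0)):
        hx = cell(x[t], hx)     # one step of the recurrence
        outputs.append(hx)
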
Example #6
    def forward(self,
                input: Tensor,
                hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(
                1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size,
                             self.hidden_size,
                             dtype=input.dtype,
                             device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence
            # that the user is actually passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        _impl = _rnn_impls[self.mode]
        if batch_sizes is None:
            result = _impl(input, hx, self._flat_weights, self.bias,
                           self.num_layers, self.dropout, self.training,
                           self.bidirectional, self.batch_first)
        else:
            result = _impl(input, batch_sizes, hx, self._flat_weights,
                           self.bias, self.num_layers, self.dropout,
                           self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]

        if is_packed:
            output = PackedSequence(output, batch_sizes, sorted_indices,
                                    unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)
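
This forward handles both a plain tensor and a PackedSequence. A sketch of the two paths, assuming the module is an nn.GRU (an nn.RNN works the same way):

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence

    rnn = nn.GRU(input_size=8, hidden_size=16)

    # Plain tensor input: batch_sizes is None and hx defaults to zeros
    x = torch.randn(7, 4, 8)               # (seq_len, batch, input_size)
    output, hn = rnn(x)

    # Packed input: the forward above unpacks it and re-wraps the output
    lengths = torch.tensor([7, 5, 3, 2])   # sorted in decreasing order
    packed = pack_padded_sequence(x, lengths)
    packed_output, hn = rnn(packed)        # packed_output is a PackedSequence
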
Example #7
 def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
     self.check_forward_input(input)
     if hx is None:
         hx = torch.zeros(input.size(0),
                          self.hidden_size,
                          dtype=input.dtype,
                          device=input.device)
     self.check_forward_hidden(input, hx, '')
     return _VF.gru_cell(
         input,
         hx,
         self.weight_ih,
         self.weight_hh,
         self.bias_ih,
         self.bias_hh,
     )
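
The GRU cell follows the same calling convention as the RNN cell above; a minimal sketch, assuming torch.nn.GRUCell:

    import torch
    import torch.nn as nn

    cell = nn.GRUCell(input_size=10, hidden_size=20)
    x = torch.randn(3, 10)   # one time step: (batch, input_size)
    hx = cell(x)             # hx is zero-initialized internally when omitted
    hx = cell(x, hx)         # later steps pass the previous hidden state back in
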
Example #8
    def check_forward_hidden(self,
                             input: Tensor,
                             hx: Tensor,
                             hidden_label: str = '') -> None:
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".
                format(input.size(0), hidden_label, hx.size(0)))

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".
                format(hidden_label, hx.size(1), self.hidden_size))
Example #9
 def forward(self,
             input: Tensor,
             hx: Optional[Tuple[Tensor,
                                Tensor]] = None) -> Tuple[Tensor, Tensor]:
     self.check_forward_input(input)
     if hx is None:
         zeros = torch.zeros(input.size(0),
                             self.hidden_size,
                             dtype=input.dtype,
                             device=input.device)
         hx = (zeros, zeros)
     self.check_forward_hidden(input, hx[0], '[0]')
     self.check_forward_hidden(input, hx[1], '[1]')
     return _VF.lstm_cell(
         input,
         hx,
         self.weight_ih,
         self.weight_hh,
         self.bias_ih,
         self.bias_hh,
     )
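
Unlike the RNN and GRU cells, this forward takes and returns a (hidden, cell) pair; a minimal sketch, assuming torch.nn.LSTMCell:

    import torch
    import torch.nn as nn

    cell = nn.LSTMCell(input_size=10, hidden_size=20)
    x = torch.randn(5, 3, 10)          # (seq_len, batch, input_size)
    hx = torch.zeros(3, 20)            # hidden state
    cx = torch.zeros(3, 20)            # cell state
    for t in range(x.size(0)):
        hx, cx = cell(x[t], (hx, cx))  # both states are threaded through the loop
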
Example #10
 def check_forward_input(self, input: Tensor) -> None:
     if input.size(1) != self.input_size:
         raise RuntimeError(
             "input has inconsistent input_size: got {}, expected {}".
             format(input.size(1), self.input_size))
Example #11
    def forward(self,
                src: Tensor,
                tgt: Tensor,
                src_mask: Optional[Tensor] = None,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
            - src_mask: :math:`(S, S)`.
            - tgt_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position i is allowed to attend only the
            unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed
            to attend while the zero positions will be unchanged. If a BoolTensor is provided,
            positions with ``True`` are not allowed to attend while ``False`` values will be
            unchanged. If a FloatTensor is provided, it will be added to the attention weight.
            [src/tgt/memory]_key_padding_mask marks specified elements in the key to be ignored by
            the attention. If a ByteTensor is provided, the non-zero positions will be ignored while
            the zero positions will be unchanged. If a BoolTensor is provided, positions with the
            value of ``True`` will be ignored while positions with the value of ``False`` will be
            unchanged.

            - output: :math:`(T, N, E)`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, and E is the feature number.

        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError(
                "the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src,
                              mask=src_mask,
                              src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt,
                              memory,
                              tgt_mask=tgt_mask,
                              memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output
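
A usage sketch for the forward above, assuming torch.nn.Transformer; the shapes follow the (S, N, E) and (T, N, E) convention from the docstring, and the causal mask is built with the module's own generate_square_subsequent_mask helper:

    import torch
    import torch.nn as nn

    model = nn.Transformer(d_model=512, nhead=8)
    src = torch.randn(10, 32, 512)   # (S, N, E)
    tgt = torch.randn(20, 32, 512)   # (T, N, E)
    tgt_mask = model.generate_square_subsequent_mask(20)   # (T, T) causal mask
    out = model(src, tgt, tgt_mask=tgt_mask)               # (T, N, E)
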