Example #1
    def forward_impl(self, input, hx, batch_sizes, max_batch_size,
                     sorted_indices):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size,
                                self.hidden_size,
                                dtype=input.dtype,
                                device=input.device)
            hx = (zeros, zeros)
        else:
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias,
                              self.num_layers, self.dropout, self.training,
                              self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self._get_flat_weights(),
                              self.bias, self.num_layers, self.dropout,
                              self.training, self.bidirectional)
        output = result[0]
        hidden = result[1:]

        return output, hidden
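
For reference, the padded-tensor overload dispatched in the `batch_sizes is None` branch above can be exercised directly. A minimal sketch, assuming a recent PyTorch where a stock `nn.LSTM` exposes its weights as the `_flat_weights` list (note that `torch._VF` is an internal, unstable PyTorch namespace):

    import torch
    from torch import _VF, nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1)
    x = torch.randn(5, 3, 8)      # (seq_len, batch, input_size)
    h0 = torch.zeros(1, 3, 16)    # (num_layers * num_directions, batch, hidden)
    c0 = torch.zeros(1, 3, 16)
    # Argument order mirrors the call above: (input, hx, params, has_biases,
    # num_layers, dropout, train, bidirectional, batch_first)
    output, h_n, c_n = _VF.lstm(x, (h0, c0), lstm._flat_weights, True, 1,
                                0.0, False, False, False)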
Example #2

    def forward(self,
                input,
                params=None,
                hx=None,
                embeddings=None):  # noqa: F811

        if params is None:
            # Fall back to the module's own flat weights, looked up by name
            # (attributes that are missing resolve to None).
            params = [(lambda wn: getattr(self, wn)
                       if hasattr(self, wn) else None)(wn)
                      for wn in self._flat_weights_names]

        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size,
                                self.hidden_size,
                                dtype=input.dtype,
                                device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.lstm(input, hx, params, self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional,
                              self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, params, self.bias,
                              self.num_layers, self.dropout, self.training,
                              self.bidirectional)
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices,
                                           unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            return output, self.permute_hidden(hidden, unsorted_indices)
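
The packed-sequence branch can be driven the same way. A minimal sketch, again assuming a stock `nn.LSTM` and the internal `torch._VF` namespace; note that the packed overload takes `batch_sizes` as the second argument and drops `batch_first`:

    import torch
    from torch import _VF, nn
    from torch.nn.utils.rnn import pack_padded_sequence

    lstm = nn.LSTM(input_size=8, hidden_size=16)
    padded = torch.randn(5, 3, 8)        # (seq_len, batch, input_size)
    lengths = torch.tensor([5, 4, 2])    # must be sorted descending here
    packed = pack_padded_sequence(padded, lengths)
    h0 = torch.zeros(1, 3, 16)
    c0 = torch.zeros(1, 3, 16)
    # Mirrors the batch_sizes branch above: (data, batch_sizes, hx, params,
    # has_biases, num_layers, dropout, train, bidirectional)
    output, h_n, c_n = _VF.lstm(packed.data, packed.batch_sizes, (h0, c0),
                                lstm._flat_weights, True, 1, 0.0, False, False)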
Example #3
    def forward(self, inputs, hx):
        self.flatten_parameters()
        result = _VF.lstm(inputs, hx, self._flat_weights(), self.bias, self.num_layers,
                          self.dropout, self.training, self.bidirectional, self.batch_first)
        output = result[0]
        hidden = result[1:]

        return output, hidden
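
As a side note, `flatten_parameters()` is a real `nn.RNNBase` method that repacks the weights into one contiguous block for cuDNN (it is effectively a no-op on CPU). Whether `self._flat_weights()` is callable depends on this example's own class: in stock `nn.LSTM` it is a plain list attribute, not a method. A quick check against a stock module:

    import torch.nn as nn

    lstm = nn.LSTM(8, 16)
    lstm.flatten_parameters()          # contiguous repack; only matters on cuDNN
    print(type(lstm._flat_weights))    # <class 'list'> -- accessed, not called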
Example #4
    def forward(self, input, state=None):
        H = self.hidden_dim
        if state is None:
            state = self.new_state(input)
        # Split the flat (batch, 2*H) state into (h_0, c_0), each shaped
        # (num_layers * num_directions, batch, H) = (1, batch, H).
        state = (state[:, :H].unsqueeze(0), state[:, H:].unsqueeze(0))
        res = _VF.lstm(input, state,
                       (self.w_ih, self.w_hh, self.b_ih, self.b_hh), True, 1,
                       0.0, self.training, False, True)
        output, hidden1, hidden2 = res
        # Re-flatten (h_n, c_n) back into a single (batch, 2*H) tensor.
        outstate = torch.cat([hidden1[0], hidden2[0]], 1)
        return output, outstate
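
The `new_state` helper is not shown in this example. A plausible sketch under the convention visible above (a single flat `(batch, 2*H)` buffer holding h and c side by side); the helper name comes from the source, but its body here is an assumption:

    def new_state(self, input):
        # input is (batch, seq_len, features) since batch_first=True above;
        # return one zero buffer holding [h | c] along the feature dim.
        batch_size = input.size(0)
        return input.new_zeros(batch_size, 2 * self.hidden_dim)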