Code Example #1
    def reset_cell(self,
                   batch_size: int,
                   device: Optional[int] = None) -> None:
        self.hidden = [
            torch.zeros(batch_size, self.hidden_dim, device=device)
            for _ in range(self.layers)
        ]
        self.context = [
            torch.zeros(batch_size, self.hidden_dim, device=device)
            for _ in range(self.layers)
        ]

        if self.training and self.layer_dropout_rate > 0.0:
            self.layer_dropout = [
                get_dropout_mask(self.layer_dropout_rate, self.hidden[i])
                for i in range(self.layers)
            ]
        else:
            self.layer_dropout = None

        if self.training and self.recurrent_dropout_rate > 0.0:
            self.recurrent_dropout = [
                get_dropout_mask(self.recurrent_dropout_rate, self.hidden[i])
                for i in range(self.layers)
            ]
        else:
            self.recurrent_dropout = None
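
Every snippet on this page leans on a `get_dropout_mask` helper that samples one mask per sequence and reuses it at every timestep (variational dropout). The helper itself is not shown; below is a minimal sketch of what such a helper typically looks like, assuming the common convention of a Bernoulli keep-mask scaled by 1/(1 - p). Treat it as an illustration, not the library's exact implementation.

import torch

def get_dropout_mask(dropout_probability: float,
                     tensor_for_masking: torch.Tensor) -> torch.Tensor:
    # Bernoulli keep-mask: 1 with probability (1 - p), 0 with probability p.
    binary_mask = (
        torch.rand(tensor_for_masking.size(), device=tensor_for_masking.device)
        > dropout_probability
    )
    # Scale by 1 / keep_prob so expected activations are unchanged; the caller
    # reuses this one mask for every timestep of the sequence.
    return binary_mask.to(tensor_for_masking.dtype) / (1.0 - dropout_probability)

With that convention, multiplying a hidden state by the mask drops the same units at every step, which is exactly how `dropout_mask` is applied inside the timestep loops further down the page.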
Code Example #2
 def reset_stack(self, num_stacks: int) -> None:
     self.stacks = [[] for _ in range(num_stacks)]
     self.push_buffer = [None for _ in range(num_stacks)]
     if self.same_dropout_mask_per_instance:
         if 0.0 < self.layer_dropout_probability < 1.0:
             self.layer_dropout_mask = [[
                 get_dropout_mask(
                     self.layer_dropout_probability,
                     torch.ones(
                         layer.hidden_size,
                         device=self.layer_0.input_linearity.weight.device))
                 for _ in range(num_stacks)
             ] for layer in self.rnn_layers]
             self.layer_dropout_mask = torch.stack(
                 [torch.stack(l) for l in self.layer_dropout_mask])
         else:
             self.layer_dropout_mask = None
         if 0.0 < self.recurrent_dropout_probability < 1.0:
             self.recurrent_dropout_mask = [[
                 get_dropout_mask(
                     self.recurrent_dropout_probability,
                     torch.ones(
                         self.hidden_size,
                         device=self.layer_0.input_linearity.weight.device))
                 for _ in range(num_stacks)
             ] for _ in range(self.num_layers)]
             self.recurrent_dropout_mask = torch.stack(
                 [torch.stack(l) for l in self.recurrent_dropout_mask])
         else:
             self.recurrent_dropout_mask = None
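
The nested comprehensions above build one mask per layer and per stack, and the two `torch.stack` calls fold them into a single tensor of shape (num_layers, num_stacks, hidden_size), which is why later code can slice out a single stack's masks with `[:, i]`. A standalone sketch of just that stacking pattern, with hypothetical sizes:

import torch

num_layers, num_stacks, hidden_size = 2, 3, 4  # hypothetical sizes
masks = [
    [torch.ones(hidden_size) for _ in range(num_stacks)]
    for _ in range(num_layers)
]
stacked = torch.stack([torch.stack(per_layer) for per_layer in masks])
print(stacked.shape)  # torch.Size([2, 3, 4]) == (num_layers, num_stacks, hidden_size)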
Code Example #3
    def forward(
        self,
        inputs: torch.FloatTensor,
        batch_lengths: List[int],
        initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_timesteps, input_size)
            to apply the LSTM over.
        batch_lengths : ``List[int]``, required.
            A list of length batch_size containing the lengths of the sequences in batch.
        initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).

        Returns
        -------
        output_accumulator : ``torch.FloatTensor``
            The outputs of the LSTM for each timestep. A tensor of shape
            (batch_size, max_timesteps, hidden_size) where for a given batch
            element, all outputs past the sequence length for that batch are
            zero tensors.
        final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
            A tuple (state, memory) representing the final hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).
        """
        batch_size = inputs.size()[0]
        total_timesteps = inputs.size()[1]

        output_accumulator = inputs.new_zeros(batch_size, total_timesteps, self.hidden_size)

        if initial_state is None:
            full_batch_previous_memory = inputs.new_zeros(batch_size, self.cell_size)
            full_batch_previous_state = inputs.new_zeros(batch_size, self.hidden_size)
        else:
            full_batch_previous_state = initial_state[0].squeeze(0)
            full_batch_previous_memory = initial_state[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0 and self.training:
            dropout_mask = get_dropout_mask(
                self.recurrent_dropout_probability, full_batch_previous_state
            )
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            # What we are doing here is finding the index into the batch dimension
            # which we need to use for this timestep, because the sequences have
            # variable length, so once the index is greater than the length of this
            # particular batch sequence, we no longer need to do the computation for
            # this sequence. The key thing to recognise here is that the batch inputs
            # must be _ordered_ by length from longest (first in batch) to shortest
            # (last) so initially, we are going forwards with every sequence and as we
            # pass the index at which the shortest elements of the batch finish,
            # we stop picking them up for the computation.
            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum number of elements in the batch?
                # Second conditional: Does the next shortest sequence beyond the current batch
                # index require computation at this timestep?
                while (
                    current_length_index < (len(batch_lengths) - 1)
                    and batch_lengths[current_length_index + 1] > index
                ):
                    current_length_index += 1

            # Actually get the slices of the batch which we
            # need for the computation at this timestep.
            # shape (batch_size, cell_size)
            previous_memory = full_batch_previous_memory[0 : current_length_index + 1].clone()
            # Shape (batch_size, hidden_size)
            previous_state = full_batch_previous_state[0 : current_length_index + 1].clone()
            # Shape (batch_size, input_size)
            timestep_input = inputs[0 : current_length_index + 1, index]

            # Do the projections for all the gates all at once.
            # Both have shape (batch_size, 4 * cell_size)
            projected_input = self.input_linearity(timestep_input)
            projected_state = self.state_linearity(previous_state)

            # Main LSTM equations using relevant chunks of the big linear
            # projections of the hidden state and inputs.
            input_gate = torch.sigmoid(
                projected_input[:, (0 * self.cell_size) : (1 * self.cell_size)]
                + projected_state[:, (0 * self.cell_size) : (1 * self.cell_size)]
            )
            forget_gate = torch.sigmoid(
                projected_input[:, (1 * self.cell_size) : (2 * self.cell_size)]
                + projected_state[:, (1 * self.cell_size) : (2 * self.cell_size)]
            )
            memory_init = torch.tanh(
                projected_input[:, (2 * self.cell_size) : (3 * self.cell_size)]
                + projected_state[:, (2 * self.cell_size) : (3 * self.cell_size)]
            )
            output_gate = torch.sigmoid(
                projected_input[:, (3 * self.cell_size) : (4 * self.cell_size)]
                + projected_state[:, (3 * self.cell_size) : (4 * self.cell_size)]
            )
            memory = input_gate * memory_init + forget_gate * previous_memory

            # Here is the non-standard part of this LSTM cell; first, we clip the
            # memory cell, then we project the output of the timestep to a smaller size
            # and again clip it.

            if self.memory_cell_clip_value:
                memory = torch.clamp(
                    memory, -self.memory_cell_clip_value, self.memory_cell_clip_value
                )

            # shape (current_length_index, cell_size)
            pre_projection_timestep_output = output_gate * torch.tanh(memory)

            # shape (current_length_index, hidden_size)
            timestep_output = self.state_projection(pre_projection_timestep_output)
            if self.state_projection_clip_value:
                timestep_output = torch.clamp(
                    timestep_output,
                    -self.state_projection_clip_value,
                    self.state_projection_clip_value,
                )

            # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
            if dropout_mask is not None:
                timestep_output = timestep_output * dropout_mask[0 : current_length_index + 1]

            # We've been doing computation with less than the full batch, so here we create a new
            # variable for the whole batch at this timestep and insert the result for the
            # relevant elements of the batch into it.
            full_batch_previous_memory = full_batch_previous_memory.clone()
            full_batch_previous_state = full_batch_previous_state.clone()
            full_batch_previous_memory[0 : current_length_index + 1] = memory
            full_batch_previous_state[0 : current_length_index + 1] = timestep_output
            output_accumulator[0 : current_length_index + 1, index] = timestep_output

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, ...). As this
        # LSTM cell cannot be stacked, the first dimension here is just 1.
        final_state = (
            full_batch_previous_state.unsqueeze(0),
            full_batch_previous_memory.unsqueeze(0),
        )

        return output_accumulator, final_state
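
This `forward` only works if the batch is already sorted by decreasing sequence length, and it expects the hidden state and the memory cell to have different widths (hidden_size vs. cell_size). A call sketch with hypothetical sizes; `cell` stands in for an instance of a module with this `forward` and is an assumption, so the actual call is left commented out:

import torch

batch_size, max_len, input_size = 3, 5, 8     # hypothetical sizes
hidden_size, cell_size = 6, 10

inputs = torch.randn(batch_size, max_len, input_size)
batch_lengths = [5, 4, 2]                     # must be sorted longest-first
initial_state = (
    torch.zeros(1, batch_size, hidden_size),  # "state"  -> (1, batch, hidden_size)
    torch.zeros(1, batch_size, cell_size),    # "memory" -> (1, batch, cell_size)
)
# outputs, (state, memory) = cell(inputs, batch_lengths, initial_state)
# outputs.shape == (batch_size, max_len, hidden_size); steps past each
# sequence's length stay zero because they are never written.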
Code Example #4
File: modules.py  Project: xinkez/Lattice-ELMo
    def forward(self,
                inputs: torch.FloatTensor,
                batch_lengths: List[int],
                initial_state: Optional[Tuple[torch.Tensor,
                                              torch.Tensor]] = None,
                prevs=None):
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_timesteps, input_size)
            to apply the LSTM over.
        batch_lengths : ``List[int]``, required.
            A list of length batch_size containing the lengths of the sequences in batch.
        initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).
        Returns
        -------
        output_accumulator : ``torch.FloatTensor``
            The outputs of the LSTM for each timestep. A tensor of shape
            (batch_size, max_timesteps, hidden_size) where for a given batch
            element, all outputs past the sequence length for that batch are
            zero tensors.
        final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
            A tuple (state, memory) representing the final hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).
        """
        batch_size = inputs.size()[0]
        total_timesteps = inputs.size()[1]

        output_accumulator = inputs.new_zeros(batch_size, total_timesteps,
                                              self.hidden_size)

        if initial_state is None:
            full_batch_previous_memory = inputs.new_zeros(
                batch_size, self.cell_size)
            full_batch_previous_state = inputs.new_zeros(
                batch_size, self.hidden_size)
        else:
            full_batch_previous_state = initial_state[0].squeeze(0)
            full_batch_previous_memory = initial_state[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0 and self.training:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            full_batch_previous_state)
        else:
            dropout_mask = None

        hs, cs = [], []

        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            else:
                while (current_length_index < (len(batch_lengths) - 1)
                       and batch_lengths[current_length_index + 1] > index):
                    current_length_index += 1

            if timestep != 0:
                previous_memory, previous_state = self.get_pooled_states(
                    timestep, prevs, hs, cs, current_length_index)
            else:
                previous_memory = full_batch_previous_memory[
                    0:current_length_index + 1].clone()
                previous_state = full_batch_previous_state[
                    0:current_length_index + 1].clone()

            timestep_input = inputs[0:current_length_index + 1, index]

            projected_input = self.input_linearity(timestep_input)
            projected_state = self.state_linearity(previous_state)

            input_gate = torch.sigmoid(
                projected_input[:, (0 * self.cell_size):(1 * self.cell_size)] +
                projected_state[:, (0 * self.cell_size):(1 * self.cell_size)])
            forget_gate = torch.sigmoid(
                projected_input[:, (1 * self.cell_size):(2 * self.cell_size)] +
                projected_state[:, (1 * self.cell_size):(2 * self.cell_size)])
            memory_init = torch.tanh(
                projected_input[:, (2 * self.cell_size):(3 * self.cell_size)] +
                projected_state[:, (2 * self.cell_size):(3 * self.cell_size)])
            output_gate = torch.sigmoid(
                projected_input[:, (3 * self.cell_size):(4 * self.cell_size)] +
                projected_state[:, (3 * self.cell_size):(4 * self.cell_size)])
            memory = input_gate * memory_init + forget_gate * previous_memory

            if self.memory_cell_clip_value:
                memory = torch.clamp(memory, -self.memory_cell_clip_value,
                                     self.memory_cell_clip_value)

            # shape (current_length_index, cell_size)
            pre_projection_timestep_output = output_gate * torch.tanh(memory)

            # shape (current_length_index, hidden_size)
            timestep_output = self.state_projection(
                pre_projection_timestep_output)
            if self.state_projection_clip_value:
                timestep_output = torch.clamp(
                    timestep_output,
                    -self.state_projection_clip_value,
                    self.state_projection_clip_value,
                )
            if dropout_mask is not None:
                timestep_output = timestep_output * dropout_mask[
                    0:current_length_index + 1]

            full_batch_previous_memory = full_batch_previous_memory.clone()
            full_batch_previous_state = full_batch_previous_state.clone()
            full_batch_previous_memory[0:current_length_index + 1] = memory
            full_batch_previous_state[0:current_length_index +
                                      1] = timestep_output
            hs.append(full_batch_previous_memory)
            cs.append(full_batch_previous_state)
            output_accumulator[0:current_length_index + 1,
                               index] = timestep_output

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, ...). As this
        # LSTM cell cannot be stacked, the first dimension here is just 1.
        final_state = (
            full_batch_previous_state.unsqueeze(0),
            full_batch_previous_memory.unsqueeze(0),
        )

        return output_accumulator, final_state
Code Example #5
File: augmented_lstm.py  Project: sbhaktha/allennlp
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: PackedSequence,
            initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """
        Parameters
        ----------
        inputs : PackedSequence, required.
            A tensor of shape (batch_size, num_timesteps, input_size)
            to apply the LSTM over.

        initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. Each tensor has shape (1, batch_size, output_dimension).

        Returns
        -------
        A PackedSequence containing a torch.FloatTensor of shape
        (batch_size, num_timesteps, output_dimension) representing
        the outputs of the LSTM per timestep and a tuple containing
        the LSTM state, with shape (1, batch_size, hidden_size) to
        match the Pytorch API.
        """
        if not isinstance(inputs, PackedSequence):
            raise ConfigurationError(
                'inputs must be PackedSequence but got %s' % (type(inputs)))

        sequence_tensor, batch_lengths = pad_packed_sequence(inputs,
                                                             batch_first=True)
        batch_size = sequence_tensor.size()[0]
        total_timesteps = sequence_tensor.size()[1]

        # We have to use this '.data.new().resize_.fill_' pattern to create tensors with the correct
        # type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors.
        output_accumulator = Variable(sequence_tensor.data.new().resize_(
            batch_size, total_timesteps, self.hidden_size).fill_(0))
        if initial_state is None:
            full_batch_previous_memory = Variable(
                sequence_tensor.data.new().resize_(batch_size,
                                                   self.hidden_size).fill_(0))
            full_batch_previous_state = Variable(
                sequence_tensor.data.new().resize_(batch_size,
                                                   self.hidden_size).fill_(0))
        else:
            full_batch_previous_state = initial_state[0].squeeze(0)
            full_batch_previous_memory = initial_state[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            full_batch_previous_memory)
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            # What we are doing here is finding the index into the batch dimension
            # which we need to use for this timestep, because the sequences have
            # variable length, so once the index is greater than the length of this
            # particular batch sequence, we no longer need to do the computation for
            # this sequence. The key thing to recognise here is that the batch inputs
            # must be _ordered_ by length from longest (first in batch) to shortest
            # (last) so initially, we are going forwards with every sequence and as we
            # pass the index at which the shortest elements of the batch finish,
            # we stop picking them up for the computation.
            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum number of elements in the batch?
                # Second conditional: Does the next shortest sequence beyond the current batch
                # index require computation at this timestep?
                while current_length_index < (len(batch_lengths) - 1) and \
                                batch_lengths[current_length_index + 1] > index:
                    current_length_index += 1

            # Actually get the slices of the batch which we need for the computation at this timestep.
            previous_memory = full_batch_previous_memory[
                0:current_length_index + 1].clone()
            previous_state = full_batch_previous_state[0:current_length_index +
                                                       1].clone()
            timestep_input = sequence_tensor[0:current_length_index + 1, index]

            # Do the projections for all the gates all at once.
            projected_input = self.input_linearity(timestep_input)
            projected_state = self.state_linearity(previous_state)

            # Main LSTM equations using relevant chunks of the big linear
            # projections of the hidden state and inputs.
            input_gate = torch.sigmoid(
                projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] +
                projected_state[:, 0 * self.hidden_size:1 * self.hidden_size])
            forget_gate = torch.sigmoid(
                projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] +
                projected_state[:, 1 * self.hidden_size:2 * self.hidden_size])
            memory_init = torch.tanh(
                projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] +
                projected_state[:, 2 * self.hidden_size:3 * self.hidden_size])
            output_gate = torch.sigmoid(
                projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] +
                projected_state[:, 3 * self.hidden_size:4 * self.hidden_size])
            memory = input_gate * memory_init + forget_gate * previous_memory
            timestep_output = output_gate * torch.tanh(memory)

            if self.use_highway:
                highway_gate = torch.sigmoid(
                    projected_input[:, 4 * self.hidden_size:5 *
                                    self.hidden_size] +
                    projected_state[:,
                                    4 * self.hidden_size:5 * self.hidden_size])
                highway_input_projection = projected_input[:, 5 *
                                                           self.hidden_size:6 *
                                                           self.hidden_size]
                timestep_output = highway_gate * timestep_output + (
                    1 - highway_gate) * highway_input_projection

            # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
            if dropout_mask is not None and self.training:
                timestep_output = timestep_output * dropout_mask[
                    0:current_length_index + 1]

            # We've been doing computation with less than the full batch, so here we create a new
            # variable for the whole batch at this timestep and insert the result for the
            # relevant elements of the batch into it.
            full_batch_previous_memory = Variable(
                full_batch_previous_memory.data.clone())
            full_batch_previous_state = Variable(
                full_batch_previous_state.data.clone())
            full_batch_previous_memory[0:current_length_index + 1] = memory
            full_batch_previous_state[0:current_length_index +
                                      1] = timestep_output
            output_accumulator[0:current_length_index + 1,
                               index] = timestep_output

        output_accumulator = pack_padded_sequence(output_accumulator,
                                                  batch_lengths,
                                                  batch_first=True)

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, hidden_size). As this
        # LSTM cannot be stacked, the first dimension here is just 1.
        final_state = (full_batch_previous_state.unsqueeze(0),
                       full_batch_previous_memory.unsqueeze(0))

        return output_accumulator, final_state
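
Because this variant takes a `PackedSequence`, a caller first sorts the batch, packs it, runs the module, and unpacks the result. A minimal usage sketch; `lstm` stands in for an instance of a module with this `forward` and is an assumption, so its call is commented out:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

batch_size, max_len, input_size = 3, 5, 8     # hypothetical sizes
sequence_tensor = torch.randn(batch_size, max_len, input_size)
lengths = [5, 4, 2]                           # sorted longest-first

packed = pack_padded_sequence(sequence_tensor, lengths, batch_first=True)
# packed_output, (state, memory) = lstm(packed)
# padded, _ = pad_packed_sequence(packed_output, batch_first=True)
# padded.shape == (batch_size, max_len, hidden_size)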
Code Example #6
File: augmented_lstm.py  Project: shutianlin/allennlp
    def forward(
        self,
        inputs: PackedSequence,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Warning: in a regular model, it is better to use the BiAugmentedLstm class.

        Given an input batch of sequential data such as word embeddings, produces a single layer unidirectional
        AugmentedLSTM representation of the sequential input and new state tensors.

        Args:
            inputs (PackedSequence): `bsize` sequences of shape `(len, input_dim)` each, in PackedSequence format
            states (Tuple[torch.Tensor, torch.Tensor]): Tuple of tensors containing the initial hidden state and
                the cell state of each element in the batch. Each of these tensors has a dimension of
                (1 x bsize x nhid). Defaults to `None`.

        Returns:
            Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
                AugmentedLSTM representation of input and the state of the LSTM `t = seq_len`.
                Shape of representation is (bsize x seq_len x representation_dim).
                Shape of each state is (1 x bsize x nhid).

        """
        if not isinstance(inputs, PackedSequence):
            raise ConfigurationError(
                "inputs must be PackedSequence but got %s" % (type(inputs)))

        sequence_tensor, batch_lengths = pad_packed_sequence(inputs,
                                                             batch_first=True)
        batch_size = sequence_tensor.size()[0]
        total_timesteps = sequence_tensor.size()[1]
        output_accumulator = sequence_tensor.new_zeros(batch_size,
                                                       total_timesteps,
                                                       self.lstm_dim)
        if states is None:
            full_batch_previous_memory = sequence_tensor.new_zeros(
                batch_size, self.lstm_dim)
            full_batch_previous_state = sequence_tensor.data.new_zeros(
                batch_size, self.lstm_dim)
        else:
            full_batch_previous_state = states[0].squeeze(0)
            full_batch_previous_memory = states[1].squeeze(0)
        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            full_batch_previous_memory)
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum
                # number of elements in the batch?
                # Second conditional: Does the next shortest
                # sequence beyond the current batch
                # index require computation at this timestep?
                while (current_length_index < (len(batch_lengths) - 1)
                       and batch_lengths[current_length_index + 1] > index):
                    current_length_index += 1

            previous_memory = full_batch_previous_memory[
                0:current_length_index + 1].clone()
            previous_state = full_batch_previous_state[0:current_length_index +
                                                       1].clone()
            timestep_input = sequence_tensor[0:current_length_index + 1, index]
            timestep_output, memory = self.cell(
                timestep_input,
                (previous_state, previous_memory),
                dropout_mask[0:current_length_index +
                             1] if dropout_mask is not None else None,
            )
            full_batch_previous_memory = full_batch_previous_memory.data.clone()
            full_batch_previous_state = full_batch_previous_state.data.clone()
            full_batch_previous_memory[0:current_length_index + 1] = memory
            full_batch_previous_state[0:current_length_index +
                                      1] = timestep_output
            output_accumulator[0:current_length_index + 1,
                               index, :] = timestep_output

        output_accumulator = pack_padded_sequence(output_accumulator,
                                                  batch_lengths,
                                                  batch_first=True)

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, lstm_dim). As this
        # LSTM cannot be stacked, the first dimension here is just 1.
        final_state = (
            full_batch_previous_state.unsqueeze(0),
            full_batch_previous_memory.unsqueeze(0),
        )
        return output_accumulator, final_state
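
Unlike the previous examples, this version factors the per-timestep arithmetic into `self.cell`, which is called with the sliced inputs, a `(state, memory)` tuple, and an optional dropout-mask slice, and must return `(new_state, new_memory)`. The cell itself is not shown here; the following is a hypothetical minimal cell with that calling convention, using the standard LSTM equations (class name and internals are assumptions, not the library's actual cell):

from typing import Optional, Tuple

import torch


class MinimalLstmCell(torch.nn.Module):
    # Hypothetical cell matching the call site above:
    # (inputs, (h, c), dropout_mask) -> (new_h, new_c).
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.input_linearity = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=False)
        self.state_linearity = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)

    def forward(self,
                x: torch.Tensor,
                states: Tuple[torch.Tensor, torch.Tensor],
                dropout_mask: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        h, c = states
        projected = self.input_linearity(x) + self.state_linearity(h)
        i, f, g, o = projected.chunk(4, dim=-1)
        new_c = torch.sigmoid(i) * torch.tanh(g) + torch.sigmoid(f) * c
        new_h = torch.sigmoid(o) * torch.tanh(new_c)
        if dropout_mask is not None:
            # Same mask at every timestep (variational dropout on the output).
            new_h = new_h * dropout_mask
        return new_h, new_c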
Code Example #7
File: augmented_lstm.py  Project: apmoore1/allennlp
    def forward(self,  # pylint: disable=arguments-differ
                inputs: PackedSequence,
                initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """
        Parameters
        ----------
        inputs : PackedSequence, required.
            A tensor of shape (batch_size, num_timesteps, input_size)
            to apply the LSTM over.

        initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. Each tensor has shape (1, batch_size, output_dimension).

        Returns
        -------
        A PackedSequence containing a torch.FloatTensor of shape
        (batch_size, num_timesteps, output_dimension) representing
        the outputs of the LSTM per timestep and a tuple containing
        the LSTM state, with shape (1, batch_size, hidden_size) to
        match the Pytorch API.
        """
        if not isinstance(inputs, PackedSequence):
            raise ConfigurationError('inputs must be PackedSequence but got %s' % (type(inputs)))

        sequence_tensor, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
        batch_size = sequence_tensor.size()[0]
        total_timesteps = sequence_tensor.size()[1]

        output_accumulator = sequence_tensor.new_zeros(batch_size, total_timesteps, self.hidden_size)
        if initial_state is None:
            full_batch_previous_memory = sequence_tensor.new_zeros(batch_size, self.hidden_size)
            full_batch_previous_state = sequence_tensor.data.new_zeros(batch_size, self.hidden_size)
        else:
            full_batch_previous_state = initial_state[0].squeeze(0)
            full_batch_previous_memory = initial_state[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability, full_batch_previous_memory)
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            # What we are doing here is finding the index into the batch dimension
            # which we need to use for this timestep, because the sequences have
            # variable length, so once the index is greater than the length of this
            # particular batch sequence, we no longer need to do the computation for
            # this sequence. The key thing to recognise here is that the batch inputs
            # must be _ordered_ by length from longest (first in batch) to shortest
            # (last) so initially, we are going forwards with every sequence and as we
            # pass the index at which the shortest elements of the batch finish,
            # we stop picking them up for the computation.
            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum number of elements in the batch?
                # Second conditional: Does the next shortest sequence beyond the current batch
                # index require computation at this timestep?
                while current_length_index < (len(batch_lengths) - 1) and \
                                batch_lengths[current_length_index + 1] > index:
                    current_length_index += 1

            # Actually get the slices of the batch which we need for the computation at this timestep.
            previous_memory = full_batch_previous_memory[0: current_length_index + 1].clone()
            previous_state = full_batch_previous_state[0: current_length_index + 1].clone()
            timestep_input = sequence_tensor[0: current_length_index + 1, index]

            # Do the projections for all the gates all at once.
            projected_input = self.input_linearity(timestep_input)
            projected_state = self.state_linearity(previous_state)

            # Main LSTM equations using relevant chunks of the big linear
            # projections of the hidden state and inputs.
            input_gate = torch.sigmoid(projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] +
                                       projected_state[:, 0 * self.hidden_size:1 * self.hidden_size])
            forget_gate = torch.sigmoid(projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] +
                                        projected_state[:, 1 * self.hidden_size:2 * self.hidden_size])
            memory_init = torch.tanh(projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] +
                                     projected_state[:, 2 * self.hidden_size:3 * self.hidden_size])
            output_gate = torch.sigmoid(projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] +
                                        projected_state[:, 3 * self.hidden_size:4 * self.hidden_size])
            memory = input_gate * memory_init + forget_gate * previous_memory
            timestep_output = output_gate * torch.tanh(memory)

            if self.use_highway:
                highway_gate = torch.sigmoid(projected_input[:, 4 * self.hidden_size:5 * self.hidden_size] +
                                             projected_state[:, 4 * self.hidden_size:5 * self.hidden_size])
                highway_input_projection = projected_input[:, 5 * self.hidden_size:6 * self.hidden_size]
                timestep_output = highway_gate * timestep_output + (1 - highway_gate) * highway_input_projection

            # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
            if dropout_mask is not None and self.training:
                timestep_output = timestep_output * dropout_mask[0: current_length_index + 1]

            # We've been doing computation with less than the full batch, so here we create a new
            # variable for the whole batch at this timestep and insert the result for the
            # relevant elements of the batch into it.
            full_batch_previous_memory = full_batch_previous_memory.data.clone()
            full_batch_previous_state = full_batch_previous_state.data.clone()
            full_batch_previous_memory[0:current_length_index + 1] = memory
            full_batch_previous_state[0:current_length_index + 1] = timestep_output
            output_accumulator[0:current_length_index + 1, index] = timestep_output

        output_accumulator = pack_padded_sequence(output_accumulator, batch_lengths, batch_first=True)

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, hidden_size). As this
        # LSTM cannot be stacked, the first dimension here is just 1.
        final_state = (full_batch_previous_state.unsqueeze(0),
                       full_batch_previous_memory.unsqueeze(0))

        return output_accumulator, final_state
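
The highway branch above blends the LSTM output with a linear projection of the input, gated elementwise: with gate g, the result is g * lstm_output + (1 - g) * projected_input. A standalone numerical sketch of that combination (all tensors are hypothetical placeholders):

import torch

highway_gate = torch.sigmoid(torch.randn(2, 4))   # elementwise gate in (0, 1)
lstm_output = torch.randn(2, 4)                   # output_gate * tanh(memory)
highway_input_projection = torch.randn(2, 4)      # linear projection of the input

timestep_output = (highway_gate * lstm_output
                   + (1 - highway_gate) * highway_input_projection)
# Where the gate is near 1 the LSTM output dominates; near 0 the (projected)
# input passes through almost unchanged.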
Code Example #8
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.FloatTensor,
                batch_lengths: List[int],
                initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_timesteps, input_size)
            to apply the LSTM over.
        batch_lengths : ``List[int]``, required.
            A list of length batch_size containing the lengths of the sequences in batch.
        initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).

        Returns
        -------
        output_accumulator : ``torch.FloatTensor``
            The outputs of the LSTM for each timestep. A tensor of shape
            (batch_size, max_timesteps, hidden_size) where for a given batch
            element, all outputs past the sequence length for that batch are
            zero tensors.
        final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
            A tuple (state, memory) representing the final hidden state and memory
            of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
            ``memory`` has shape (1, batch_size, cell_size).
        """
        batch_size = inputs.size()[0]
        total_timesteps = inputs.size()[1]

        output_accumulator = inputs.new_zeros(batch_size, total_timesteps, self.hidden_size)

        if initial_state is None:
            full_batch_previous_memory = inputs.new_zeros(batch_size, self.cell_size)
            full_batch_previous_state = inputs.new_zeros(batch_size, self.hidden_size)
        else:
            full_batch_previous_state = initial_state[0].squeeze(0)
            full_batch_previous_memory = initial_state[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0 and self.training:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            full_batch_previous_state)
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            # What we are doing here is finding the index into the batch dimension
            # which we need to use for this timestep, because the sequences have
            # variable length, so once the index is greater than the length of this
            # particular batch sequence, we no longer need to do the computation for
            # this sequence. The key thing to recognise here is that the batch inputs
            # must be _ordered_ by length from longest (first in batch) to shortest
            # (last) so initially, we are going forwards with every sequence and as we
            # pass the index at which the shortest elements of the batch finish,
            # we stop picking them up for the computation.
            if self.go_forward:
                while batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum number of elements in the batch?
                # Second conditional: Does the next shortest sequence beyond the current batch
                # index require computation at this timestep?
                while current_length_index < (len(batch_lengths) - 1) and \
                                batch_lengths[current_length_index + 1] > index:
                    current_length_index += 1

            # Actually get the slices of the batch which we
            # need for the computation at this timestep.
            # shape (batch_size, cell_size)
            previous_memory = full_batch_previous_memory[0: current_length_index + 1].clone()
            # Shape (batch_size, hidden_size)
            previous_state = full_batch_previous_state[0: current_length_index + 1].clone()
            # Shape (batch_size, input_size)
            timestep_input = inputs[0: current_length_index + 1, index]

            # Do the projections for all the gates all at once.
            # Both have shape (batch_size, 4 * cell_size)
            projected_input = self.input_linearity(timestep_input)
            projected_state = self.state_linearity(previous_state)

            # Main LSTM equations using relevant chunks of the big linear
            # projections of the hidden state and inputs.
            input_gate = torch.sigmoid(projected_input[:, (0 * self.cell_size):(1 * self.cell_size)] +
                                       projected_state[:, (0 * self.cell_size):(1 * self.cell_size)])
            forget_gate = torch.sigmoid(projected_input[:, (1 * self.cell_size):(2 * self.cell_size)] +
                                        projected_state[:, (1 * self.cell_size):(2 * self.cell_size)])
            memory_init = torch.tanh(projected_input[:, (2 * self.cell_size):(3 * self.cell_size)] +
                                     projected_state[:, (2 * self.cell_size):(3 * self.cell_size)])
            output_gate = torch.sigmoid(projected_input[:, (3 * self.cell_size):(4 * self.cell_size)] +
                                        projected_state[:, (3 * self.cell_size):(4 * self.cell_size)])
            memory = input_gate * memory_init + forget_gate * previous_memory

            # Here is the non-standard part of this LSTM cell; first, we clip the
            # memory cell, then we project the output of the timestep to a smaller size
            # and again clip it.

            if self.memory_cell_clip_value:
                # pylint: disable=invalid-unary-operand-type
                memory = torch.clamp(memory, -self.memory_cell_clip_value, self.memory_cell_clip_value)

            # shape (current_length_index, cell_size)
            pre_projection_timestep_output = output_gate * torch.tanh(memory)

            # shape (current_length_index, hidden_size)
            timestep_output = self.state_projection(pre_projection_timestep_output)
            if self.state_projection_clip_value:
                # pylint: disable=invalid-unary-operand-type
                timestep_output = torch.clamp(timestep_output,
                                              -self.state_projection_clip_value,
                                              self.state_projection_clip_value)

            # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
            if dropout_mask is not None:
                timestep_output = timestep_output * dropout_mask[0: current_length_index + 1]

            # We've been doing computation with less than the full batch, so here we create a new
            # variable for the whole batch at this timestep and insert the result for the
            # relevant elements of the batch into it.
            full_batch_previous_memory = full_batch_previous_memory.clone()
            full_batch_previous_state = full_batch_previous_state.clone()
            full_batch_previous_memory[0:current_length_index + 1] = memory
            full_batch_previous_state[0:current_length_index + 1] = timestep_output
            output_accumulator[0:current_length_index + 1, index] = timestep_output

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, ...). As this
        # LSTM cell cannot be stacked, the first dimension here is just 1.
        final_state = (full_batch_previous_state.unsqueeze(0),
                       full_batch_previous_memory.unsqueeze(0))

        return output_accumulator, final_state
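
The "non-standard part" flagged in the comments is the pair of clips around a down-projection from cell_size to hidden_size, which is what lets the memory cell be wider than the emitted state. A standalone sketch of just that step, with hypothetical sizes and clip values:

import torch

cell_size, hidden_size = 10, 6                 # hypothetical sizes
memory_clip, projection_clip = 3.0, 1.0        # hypothetical clip values
state_projection = torch.nn.Linear(cell_size, hidden_size, bias=False)

memory = torch.randn(2, cell_size) * 5
memory = torch.clamp(memory, -memory_clip, memory_clip)
output_gate = torch.sigmoid(torch.randn(2, cell_size))
pre_projection = output_gate * torch.tanh(memory)             # (2, cell_size)
timestep_output = state_projection(pre_projection)            # (2, hidden_size)
timestep_output = torch.clamp(timestep_output, -projection_clip, projection_clip)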
Code Example #9
 def _apply_push(self) -> None:
     index_list = []
     inputs = []
     initial_state = []
     layer_dropout_mask = []
     recurrent_dropout_mask = []
     for i, (stack, buffer) in enumerate(zip(self.stacks,
                                             self.push_buffer)):
         if buffer is not None:
             index_list.append(i)
             inputs.append(buffer['stack_rnn_input'].unsqueeze(0))
             if len(stack) > 0:
                 initial_state.append(
                     (stack[-1]['stack_rnn_state'].unsqueeze(1),
                      stack[-1]['stack_rnn_memory'].unsqueeze(1)))
             else:
                 initial_state.append((buffer['stack_rnn_input'].new_zeros(
                     self.num_layers, 1, self.hidden_size), ) * 2)
             if self.same_dropout_mask_per_instance:
                 if self.layer_dropout_mask is not None:
                     layer_dropout_mask.append(
                         self.layer_dropout_mask[:, i].unsqueeze(1))
                 if self.recurrent_dropout_mask is not None:
                     recurrent_dropout_mask.append(
                         self.recurrent_dropout_mask[:, i].unsqueeze(1))
             else:
                 if 0.0 < self.layer_dropout_probability < 1.0:
                     layer_dropout_mask.append(
                         get_dropout_mask(
                             self.layer_dropout_probability,
                             torch.ones(self.num_layers,
                                        1,
                                        self.hidden_size,
                                        device=self.layer_0.input_linearity.
                                        weight.device)))
                 if 0.0 < self.recurrent_dropout_probability < 1.0:
                     recurrent_dropout_mask.append(
                         get_dropout_mask(
                             self.recurrent_dropout_probability,
                             torch.ones(self.num_layers,
                                        1,
                                        self.hidden_size,
                                        device=self.layer_0.input_linearity.
                                        weight.device)))
     if len(layer_dropout_mask) == 0:
         layer_dropout_mask = None
     if len(recurrent_dropout_mask) == 0:
         recurrent_dropout_mask = None
     if len(index_list) > 0:
         inputs = torch.cat(inputs, 0)
         initial_state = list(torch.cat(t, 1) for t in zip(*initial_state))
         if layer_dropout_mask is not None:
             layer_dropout_mask = torch.cat(layer_dropout_mask, 1)
         if recurrent_dropout_mask is not None:
             recurrent_dropout_mask = torch.cat(recurrent_dropout_mask, 1)
         output_state, output_memory = self._forward(
             inputs, initial_state, layer_dropout_mask,
             recurrent_dropout_mask)
         for i, stack_index in enumerate(index_list):
             output = {
                 'stack_rnn_state': output_state[:, i, :],
                 'stack_rnn_memory': output_memory[:, i, :],
                 'stack_rnn_output': output_state[-1, i, :]
             }
             output.update(self.push_buffer[stack_index])
             self.stacks[stack_index].append(output)
             self.push_buffer[stack_index] = None
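
Each successful push appends a dict to the corresponding stack that carries the full per-layer RNN state, the per-layer memory, the top-layer output, and whatever keys were already in the push buffer. A standalone sketch of what one such stack entry looks like, with hypothetical sizes (the per-layer shapes are inferred from the `output_state[:, i, :]` and `output_state[-1, i, :]` slices above):

import torch

num_layers, hidden_size, input_dim = 2, 6, 8                    # hypothetical sizes
entry = {
    'stack_rnn_state': torch.zeros(num_layers, hidden_size),    # all layers
    'stack_rnn_memory': torch.zeros(num_layers, hidden_size),   # all layers
    'stack_rnn_output': torch.zeros(hidden_size),                # top layer only
    'stack_rnn_input': torch.zeros(input_dim),                   # carried over from the push buffer
}
stacks = [[entry]]                                               # one stack with one pushed element
top = stacks[0][-1]
summary = top['stack_rnn_output']                                # what a caller would typically read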