def reset_cell(self, batch_size: int, device: Optional[int] = None) -> None:
    self.hidden = [
        torch.zeros(batch_size, self.hidden_dim, device=device)
        for _ in range(self.layers)
    ]
    self.context = [
        torch.zeros(batch_size, self.hidden_dim, device=device)
        for _ in range(self.layers)
    ]
    if self.training and self.layer_dropout_rate > 0.0:
        self.layer_dropout = [
            get_dropout_mask(self.layer_dropout_rate, self.hidden[i])
            for i in range(self.layers)
        ]
    else:
        self.layer_dropout = None
    if self.training and self.recurrent_dropout_rate > 0.0:
        self.recurrent_dropout = [
            get_dropout_mask(self.recurrent_dropout_rate, self.hidden[i])
            for i in range(self.layers)
        ]
    else:
        self.recurrent_dropout = None
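# The masks above come from a `get_dropout_mask` helper that is not shown here.
# A minimal, self-contained sketch of such a helper, assuming the usual
# inverted-dropout convention (AllenNLP's `allennlp.nn.util.get_dropout_mask`
# behaves this way); this is an illustrative stand-in, not the project's code.
import torch


def get_dropout_mask_sketch(dropout_probability: float,
                            tensor_for_masking: torch.Tensor) -> torch.Tensor:
    # Sample a Bernoulli keep/drop mask with the same shape and device as the
    # reference tensor, then scale by 1 / (1 - p) so the expected value of the
    # masked activations is unchanged (inverted dropout).
    binary_mask = (
        torch.rand(tensor_for_masking.size(), device=tensor_for_masking.device)
        > dropout_probability
    )
    return binary_mask.to(tensor_for_masking.dtype) / (1.0 - dropout_probability)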
def reset_stack(self, num_stacks: int) -> None:
    self.stacks = [[] for _ in range(num_stacks)]
    self.push_buffer = [None for _ in range(num_stacks)]
    if self.same_dropout_mask_per_instance:
        if 0.0 < self.layer_dropout_probability < 1.0:
            self.layer_dropout_mask = [[
                get_dropout_mask(
                    self.layer_dropout_probability,
                    torch.ones(layer.hidden_size,
                               device=self.layer_0.input_linearity.weight.device))
                for _ in range(num_stacks)
            ] for layer in self.rnn_layers]
            self.layer_dropout_mask = torch.stack(
                [torch.stack(l) for l in self.layer_dropout_mask])
        else:
            self.layer_dropout_mask = None
        if 0.0 < self.recurrent_dropout_probability < 1.0:
            self.recurrent_dropout_mask = [[
                get_dropout_mask(
                    self.recurrent_dropout_probability,
                    torch.ones(self.hidden_size,
                               device=self.layer_0.input_linearity.weight.device))
                for _ in range(num_stacks)
            ] for _ in range(self.num_layers)]
            self.recurrent_dropout_mask = torch.stack(
                [torch.stack(l) for l in self.recurrent_dropout_mask])
        else:
            self.recurrent_dropout_mask = None
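# The double `torch.stack` above turns the list-of-lists of per-stack masks into
# one tensor of shape (num_layers, num_stacks, hidden_size), which `_apply_push`
# later indexes with `[:, i]`. A self-contained shape check with made-up sizes:
import torch

num_layers, num_stacks, hidden_size = 2, 3, 5  # illustrative sizes only
masks = [[torch.ones(hidden_size) for _ in range(num_stacks)]
         for _ in range(num_layers)]
stacked = torch.stack([torch.stack(per_layer) for per_layer in masks])
assert stacked.shape == (num_layers, num_stacks, hidden_size)
# Slicing one stack's masks for every layer, as _apply_push does:
assert stacked[:, 1].shape == (num_layers, hidden_size)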
def forward(
    self,
    inputs: torch.FloatTensor,
    batch_lengths: List[int],
    initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
    """
    Parameters
    ----------
    inputs : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, num_timesteps, input_size)
        to apply the LSTM over.
    batch_lengths : ``List[int]``, required.
        A list of length batch_size containing the lengths of the sequences in batch.
    initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None)
        A tuple (state, memory) representing the initial hidden state and memory
        of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
        ``memory`` has shape (1, batch_size, cell_size).

    Returns
    -------
    output_accumulator : ``torch.FloatTensor``
        The outputs of the LSTM for each timestep. A tensor of shape
        (batch_size, max_timesteps, hidden_size) where for a given batch
        element, all outputs past the sequence length for that batch are
        zero tensors.
    final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
        A tuple (state, memory) representing the final hidden state and memory
        of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and
        the ``memory`` has shape (1, batch_size, cell_size).
    """
    batch_size = inputs.size()[0]
    total_timesteps = inputs.size()[1]

    output_accumulator = inputs.new_zeros(batch_size, total_timesteps, self.hidden_size)

    if initial_state is None:
        full_batch_previous_memory = inputs.new_zeros(batch_size, self.cell_size)
        full_batch_previous_state = inputs.new_zeros(batch_size, self.hidden_size)
    else:
        full_batch_previous_state = initial_state[0].squeeze(0)
        full_batch_previous_memory = initial_state[1].squeeze(0)

    current_length_index = batch_size - 1 if self.go_forward else 0
    if self.recurrent_dropout_probability > 0.0 and self.training:
        dropout_mask = get_dropout_mask(
            self.recurrent_dropout_probability, full_batch_previous_state
        )
    else:
        dropout_mask = None

    for timestep in range(total_timesteps):
        # The index depends on which end we start.
        index = timestep if self.go_forward else total_timesteps - timestep - 1

        # What we are doing here is finding the index into the batch dimension
        # which we need to use for this timestep, because the sequences have
        # variable length, so once the index is greater than the length of this
        # particular batch sequence, we no longer need to do the computation for
        # this sequence. The key thing to recognise here is that the batch inputs
        # must be _ordered_ by length from longest (first in batch) to shortest
        # (last) so initially, we are going forwards with every sequence and as we
        # pass the index at which the shortest elements of the batch finish,
        # we stop picking them up for the computation.
        if self.go_forward:
            while batch_lengths[current_length_index] <= index:
                current_length_index -= 1
        # If we're going backwards, we are _picking up_ more indices.
        else:
            # First conditional: Are we already at the maximum number of elements in the batch?
            # Second conditional: Does the next shortest sequence beyond the current batch
            # index require computation at this timestep?
            while (
                current_length_index < (len(batch_lengths) - 1)
                and batch_lengths[current_length_index + 1] > index
            ):
                current_length_index += 1

        # Actually get the slices of the batch which we
        # need for the computation at this timestep.
        # shape (batch_size, cell_size)
        previous_memory = full_batch_previous_memory[0 : current_length_index + 1].clone()
        # Shape (batch_size, hidden_size)
        previous_state = full_batch_previous_state[0 : current_length_index + 1].clone()
        # Shape (batch_size, input_size)
        timestep_input = inputs[0 : current_length_index + 1, index]

        # Do the projections for all the gates all at once.
        # Both have shape (batch_size, 4 * cell_size)
        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        # Main LSTM equations using relevant chunks of the big linear
        # projections of the hidden state and inputs.
        input_gate = torch.sigmoid(
            projected_input[:, (0 * self.cell_size) : (1 * self.cell_size)]
            + projected_state[:, (0 * self.cell_size) : (1 * self.cell_size)]
        )
        forget_gate = torch.sigmoid(
            projected_input[:, (1 * self.cell_size) : (2 * self.cell_size)]
            + projected_state[:, (1 * self.cell_size) : (2 * self.cell_size)]
        )
        memory_init = torch.tanh(
            projected_input[:, (2 * self.cell_size) : (3 * self.cell_size)]
            + projected_state[:, (2 * self.cell_size) : (3 * self.cell_size)]
        )
        output_gate = torch.sigmoid(
            projected_input[:, (3 * self.cell_size) : (4 * self.cell_size)]
            + projected_state[:, (3 * self.cell_size) : (4 * self.cell_size)]
        )
        memory = input_gate * memory_init + forget_gate * previous_memory

        # Here is the non-standard part of this LSTM cell; first, we clip the
        # memory cell, then we project the output of the timestep to a smaller size
        # and again clip it.
        if self.memory_cell_clip_value:
            memory = torch.clamp(
                memory, -self.memory_cell_clip_value, self.memory_cell_clip_value
            )

        # shape (current_length_index, cell_size)
        pre_projection_timestep_output = output_gate * torch.tanh(memory)

        # shape (current_length_index, hidden_size)
        timestep_output = self.state_projection(pre_projection_timestep_output)
        if self.state_projection_clip_value:
            timestep_output = torch.clamp(
                timestep_output,
                -self.state_projection_clip_value,
                self.state_projection_clip_value,
            )

        # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
        if dropout_mask is not None:
            timestep_output = timestep_output * dropout_mask[0 : current_length_index + 1]

        # We've been doing computation with less than the full batch, so here we create a new
        # variable for the whole batch at this timestep and insert the result for the
        # relevant elements of the batch into it.
        full_batch_previous_memory = full_batch_previous_memory.clone()
        full_batch_previous_state = full_batch_previous_state.clone()
        full_batch_previous_memory[0 : current_length_index + 1] = memory
        full_batch_previous_state[0 : current_length_index + 1] = timestep_output
        output_accumulator[0 : current_length_index + 1, index] = timestep_output

    # Mimic the pytorch API by returning state in the following shape:
    # (num_layers * num_directions, batch_size, ...). As this
    # LSTM cell cannot be stacked, the first dimension here is just 1.
    final_state = (
        full_batch_previous_state.unsqueeze(0),
        full_batch_previous_memory.unsqueeze(0),
    )

    return output_accumulator, final_state
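# The while-loops above maintain `current_length_index` so that, at every
# timestep, rows 0..current_length_index are exactly the sequences still
# running. A self-contained illustration of the forward-direction bookkeeping,
# with made-up lengths (the batch must be sorted longest-first):
batch_lengths_demo = [7, 6, 4, 2]
current_length_index_demo = len(batch_lengths_demo) - 1
active_rows_per_timestep = []
for index in range(max(batch_lengths_demo)):
    while batch_lengths_demo[current_length_index_demo] <= index:
        current_length_index_demo -= 1
    active_rows_per_timestep.append(current_length_index_demo + 1)
# One sequence drops out each time its length is passed.
assert active_rows_per_timestep == [4, 4, 3, 3, 2, 2, 1]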
def forward(self,
            inputs: torch.FloatTensor,
            batch_lengths: List[int],
            initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            prevs=None):
    """
    Parameters
    ----------
    inputs : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, num_timesteps, input_size)
        to apply the LSTM over.
    batch_lengths : ``List[int]``, required.
        A list of length batch_size containing the lengths of the sequences in batch.
    initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None)
        A tuple (state, memory) representing the initial hidden state and memory
        of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
        ``memory`` has shape (1, batch_size, cell_size).

    Returns
    -------
    output_accumulator : ``torch.FloatTensor``
        The outputs of the LSTM for each timestep. A tensor of shape
        (batch_size, max_timesteps, hidden_size) where for a given batch
        element, all outputs past the sequence length for that batch are
        zero tensors.
    final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
        A tuple (state, memory) representing the final hidden state and memory
        of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and
        the ``memory`` has shape (1, batch_size, cell_size).
    """
    batch_size = inputs.size()[0]
    total_timesteps = inputs.size()[1]

    output_accumulator = inputs.new_zeros(batch_size, total_timesteps,
                                          self.hidden_size)

    if initial_state is None:
        full_batch_previous_memory = inputs.new_zeros(batch_size, self.cell_size)
        full_batch_previous_state = inputs.new_zeros(batch_size, self.hidden_size)
    else:
        full_batch_previous_state = initial_state[0].squeeze(0)
        full_batch_previous_memory = initial_state[1].squeeze(0)

    current_length_index = batch_size - 1 if self.go_forward else 0
    if self.recurrent_dropout_probability > 0.0 and self.training:
        dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                        full_batch_previous_state)
    else:
        dropout_mask = None

    # Per-timestep histories of the full-batch memory and state, consumed by
    # ``get_pooled_states`` on later timesteps.
    hs, cs = [], []
    for timestep in range(total_timesteps):
        # The index depends on which end we start.
        index = timestep if self.go_forward else total_timesteps - timestep - 1

        if self.go_forward:
            while batch_lengths[current_length_index] <= index:
                current_length_index -= 1
        else:
            while (current_length_index < (len(batch_lengths) - 1)
                   and batch_lengths[current_length_index + 1] > index):
                current_length_index += 1

        if timestep != 0:
            previous_memory, previous_state = self.get_pooled_states(
                timestep, prevs, hs, cs, current_length_index)
        else:
            previous_memory = full_batch_previous_memory[0:current_length_index + 1].clone()
            previous_state = full_batch_previous_state[0:current_length_index + 1].clone()
        timestep_input = inputs[0:current_length_index + 1, index]

        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        input_gate = torch.sigmoid(
            projected_input[:, (0 * self.cell_size):(1 * self.cell_size)] +
            projected_state[:, (0 * self.cell_size):(1 * self.cell_size)])
        forget_gate = torch.sigmoid(
            projected_input[:, (1 * self.cell_size):(2 * self.cell_size)] +
            projected_state[:, (1 * self.cell_size):(2 * self.cell_size)])
        memory_init = torch.tanh(
            projected_input[:, (2 * self.cell_size):(3 * self.cell_size)] +
            projected_state[:, (2 * self.cell_size):(3 * self.cell_size)])
        output_gate = torch.sigmoid(
            projected_input[:, (3 * self.cell_size):(4 * self.cell_size)] +
            projected_state[:, (3 * self.cell_size):(4 * self.cell_size)])
        memory = input_gate * memory_init + forget_gate * previous_memory

        if self.memory_cell_clip_value:
            memory = torch.clamp(memory, -self.memory_cell_clip_value,
                                 self.memory_cell_clip_value)

        # shape (current_length_index, cell_size)
        pre_projection_timestep_output = output_gate * torch.tanh(memory)

        # shape (current_length_index, hidden_size)
        timestep_output = self.state_projection(pre_projection_timestep_output)
        if self.state_projection_clip_value:
            timestep_output = torch.clamp(
                timestep_output,
                -self.state_projection_clip_value,
                self.state_projection_clip_value,
            )

        if dropout_mask is not None:
            timestep_output = timestep_output * dropout_mask[0:current_length_index + 1]

        full_batch_previous_memory = full_batch_previous_memory.clone()
        full_batch_previous_state = full_batch_previous_state.clone()
        full_batch_previous_memory[0:current_length_index + 1] = memory
        full_batch_previous_state[0:current_length_index + 1] = timestep_output
        # Record the full-batch memory and state after this timestep for pooling.
        hs.append(full_batch_previous_memory)
        cs.append(full_batch_previous_state)
        output_accumulator[0:current_length_index + 1, index] = timestep_output

    # Mimic the pytorch API by returning state in the following shape:
    # (num_layers * num_directions, batch_size, ...). As this
    # LSTM cell cannot be stacked, the first dimension here is just 1.
    final_state = (
        full_batch_previous_state.unsqueeze(0),
        full_batch_previous_memory.unsqueeze(0),
    )

    return output_accumulator, final_state
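# `get_pooled_states` is not shown in this excerpt. Purely as a hypothetical
# sketch (not the project's implementation), one plausible reading is that it
# pools over the per-timestep histories `hs`/`cs` before feeding the result
# back as the previous memory and state; the real helper, which also receives
# `timestep` and `prevs`, may do something quite different.
from typing import List, Tuple

import torch


def get_pooled_states_sketch(hs: List[torch.Tensor], cs: List[torch.Tensor],
                             current_length_index: int
                             ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical: element-wise max over all previous timesteps, then slice
    # down to the rows that still need computation at this timestep.
    pooled_memory = torch.stack(hs, dim=0).max(dim=0).values
    pooled_state = torch.stack(cs, dim=0).max(dim=0).values
    return (pooled_memory[0:current_length_index + 1],
            pooled_state[0:current_length_index + 1])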
def forward(self,  # pylint: disable=arguments-differ
            inputs: PackedSequence,
            initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
    """
    Parameters
    ----------
    inputs : PackedSequence, required.
        A tensor of shape (batch_size, num_timesteps, input_size)
        to apply the LSTM over.
    initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None)
        A tuple (state, memory) representing the initial hidden state and memory
        of the LSTM. Each tensor has shape (1, batch_size, output_dimension).

    Returns
    -------
    A PackedSequence containing a torch.FloatTensor of shape
    (batch_size, num_timesteps, output_dimension) representing the outputs of the LSTM
    per timestep and a tuple containing the LSTM state, with shape
    (1, batch_size, hidden_size) to match the Pytorch API.
    """
    if not isinstance(inputs, PackedSequence):
        raise ConfigurationError('inputs must be PackedSequence but got %s' %
                                 (type(inputs)))

    sequence_tensor, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
    batch_size = sequence_tensor.size()[0]
    total_timesteps = sequence_tensor.size()[1]

    # We have to use this '.data.new().resize_.fill_' pattern to create tensors with the correct
    # type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors.
    output_accumulator = Variable(sequence_tensor.data.new().resize_(
        batch_size, total_timesteps, self.hidden_size).fill_(0))
    if initial_state is None:
        full_batch_previous_memory = Variable(sequence_tensor.data.new().resize_(
            batch_size, self.hidden_size).fill_(0))
        full_batch_previous_state = Variable(sequence_tensor.data.new().resize_(
            batch_size, self.hidden_size).fill_(0))
    else:
        full_batch_previous_state = initial_state[0].squeeze(0)
        full_batch_previous_memory = initial_state[1].squeeze(0)

    current_length_index = batch_size - 1 if self.go_forward else 0
    if self.recurrent_dropout_probability > 0.0:
        dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                        full_batch_previous_memory)
    else:
        dropout_mask = None

    for timestep in range(total_timesteps):
        # The index depends on which end we start.
        index = timestep if self.go_forward else total_timesteps - timestep - 1

        # What we are doing here is finding the index into the batch dimension
        # which we need to use for this timestep, because the sequences have
        # variable length, so once the index is greater than the length of this
        # particular batch sequence, we no longer need to do the computation for
        # this sequence. The key thing to recognise here is that the batch inputs
        # must be _ordered_ by length from longest (first in batch) to shortest
        # (last) so initially, we are going forwards with every sequence and as we
        # pass the index at which the shortest elements of the batch finish,
        # we stop picking them up for the computation.
        if self.go_forward:
            while batch_lengths[current_length_index] <= index:
                current_length_index -= 1
        # If we're going backwards, we are _picking up_ more indices.
        else:
            # First conditional: Are we already at the maximum number of elements in the batch?
            # Second conditional: Does the next shortest sequence beyond the current batch
            # index require computation at this timestep?
            while current_length_index < (len(batch_lengths) - 1) and \
                    batch_lengths[current_length_index + 1] > index:
                current_length_index += 1

        # Actually get the slices of the batch which we need for the computation at this timestep.
        previous_memory = full_batch_previous_memory[0:current_length_index + 1].clone()
        previous_state = full_batch_previous_state[0:current_length_index + 1].clone()
        timestep_input = sequence_tensor[0:current_length_index + 1, index]

        # Do the projections for all the gates all at once.
        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        # Main LSTM equations using relevant chunks of the big linear
        # projections of the hidden state and inputs.
        input_gate = torch.sigmoid(
            projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] +
            projected_state[:, 0 * self.hidden_size:1 * self.hidden_size])
        forget_gate = torch.sigmoid(
            projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] +
            projected_state[:, 1 * self.hidden_size:2 * self.hidden_size])
        memory_init = torch.tanh(
            projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] +
            projected_state[:, 2 * self.hidden_size:3 * self.hidden_size])
        output_gate = torch.sigmoid(
            projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] +
            projected_state[:, 3 * self.hidden_size:4 * self.hidden_size])
        memory = input_gate * memory_init + forget_gate * previous_memory
        timestep_output = output_gate * torch.tanh(memory)

        if self.use_highway:
            highway_gate = torch.sigmoid(
                projected_input[:, 4 * self.hidden_size:5 * self.hidden_size] +
                projected_state[:, 4 * self.hidden_size:5 * self.hidden_size])
            highway_input_projection = projected_input[:, 5 * self.hidden_size:6 * self.hidden_size]
            timestep_output = highway_gate * timestep_output + \
                (1 - highway_gate) * highway_input_projection

        # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
        if dropout_mask is not None and self.training:
            timestep_output = timestep_output * dropout_mask[0:current_length_index + 1]

        # We've been doing computation with less than the full batch, so here we create a new
        # variable for the whole batch at this timestep and insert the result for the
        # relevant elements of the batch into it.
        full_batch_previous_memory = Variable(full_batch_previous_memory.data.clone())
        full_batch_previous_state = Variable(full_batch_previous_state.data.clone())
        full_batch_previous_memory[0:current_length_index + 1] = memory
        full_batch_previous_state[0:current_length_index + 1] = timestep_output
        output_accumulator[0:current_length_index + 1, index] = timestep_output

    output_accumulator = pack_padded_sequence(output_accumulator, batch_lengths,
                                              batch_first=True)

    # Mimic the pytorch API by returning state in the following shape:
    # (num_layers * num_directions, batch_size, hidden_size). As this
    # LSTM cannot be stacked, the first dimension here is just 1.
    final_state = (full_batch_previous_state.unsqueeze(0),
                   full_batch_previous_memory.unsqueeze(0))

    return output_accumulator, final_state
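# This forward expects a PackedSequence and immediately pads it back with
# pad_packed_sequence. A self-contained sketch of the packing side, using only
# the public PyTorch API (the sizes are made up for illustration):
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

demo_inputs = torch.randn(4, 7, 50)   # (batch, time, input_size), rows sorted longest-first
demo_lengths = [7, 6, 4, 2]
packed = pack_padded_sequence(demo_inputs, demo_lengths, batch_first=True)

# This mirrors what the forward does on entry: recover the padded tensor and lengths.
sequence_tensor, batch_lengths = pad_packed_sequence(packed, batch_first=True)
assert sequence_tensor.shape == (4, 7, 50)
assert batch_lengths.tolist() == demo_lengths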
def forward(
    self, inputs: PackedSequence,
    states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Warning: it is usually better to use the BiAugmentedLstm class in a regular model.

    Given an input batch of sequential data such as word embeddings, produces a single layer
    unidirectional AugmentedLSTM representation of the sequential input and new state tensors.

    Args:
        inputs (PackedSequence): `bsize` sequences of shape `(len, input_dim)` each, in
            PackedSequence format
        states (Tuple[torch.Tensor, torch.Tensor]): Tuple of tensors containing the initial
            hidden state and the cell state of each element in the batch. Each of these
            tensors has a dimension of (1 x bsize x nhid). Defaults to `None`.

    Returns:
        Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
            AugmentedLSTM representation of the input and the state of the LSTM at `t = seq_len`.
            Shape of the representation is (bsize x seq_len x representation_dim).
            Shape of each state is (1 x bsize x nhid).
    """
    if not isinstance(inputs, PackedSequence):
        raise ConfigurationError("inputs must be PackedSequence but got %s" %
                                 (type(inputs)))

    sequence_tensor, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
    batch_size = sequence_tensor.size()[0]
    total_timesteps = sequence_tensor.size()[1]
    output_accumulator = sequence_tensor.new_zeros(batch_size, total_timesteps,
                                                   self.lstm_dim)
    if states is None:
        full_batch_previous_memory = sequence_tensor.new_zeros(batch_size, self.lstm_dim)
        full_batch_previous_state = sequence_tensor.data.new_zeros(batch_size, self.lstm_dim)
    else:
        full_batch_previous_state = states[0].squeeze(0)
        full_batch_previous_memory = states[1].squeeze(0)

    current_length_index = batch_size - 1 if self.go_forward else 0
    if self.recurrent_dropout_probability > 0.0:
        dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                        full_batch_previous_memory)
    else:
        dropout_mask = None

    for timestep in range(total_timesteps):
        index = timestep if self.go_forward else total_timesteps - timestep - 1

        if self.go_forward:
            while batch_lengths[current_length_index] <= index:
                current_length_index -= 1
        # If we're going backwards, we are _picking up_ more indices.
        else:
            # First conditional: Are we already at the maximum
            # number of elements in the batch?
            # Second conditional: Does the next shortest
            # sequence beyond the current batch
            # index require computation at this timestep?
            while (current_length_index < (len(batch_lengths) - 1)
                   and batch_lengths[current_length_index + 1] > index):
                current_length_index += 1

        previous_memory = full_batch_previous_memory[0:current_length_index + 1].clone()
        previous_state = full_batch_previous_state[0:current_length_index + 1].clone()
        timestep_input = sequence_tensor[0:current_length_index + 1, index]
        timestep_output, memory = self.cell(
            timestep_input,
            (previous_state, previous_memory),
            dropout_mask[0:current_length_index + 1] if dropout_mask is not None else None,
        )
        full_batch_previous_memory = full_batch_previous_memory.data.clone()
        full_batch_previous_state = full_batch_previous_state.data.clone()
        full_batch_previous_memory[0:current_length_index + 1] = memory
        full_batch_previous_state[0:current_length_index + 1] = timestep_output
        output_accumulator[0:current_length_index + 1, index, :] = timestep_output

    output_accumulator = pack_padded_sequence(output_accumulator, batch_lengths,
                                              batch_first=True)

    # Mimic the pytorch API by returning state in the following shape:
    # (num_layers * num_directions, batch_size, lstm_dim). As this
    # LSTM cannot be stacked, the first dimension here is just 1.
    final_state = (
        full_batch_previous_state.unsqueeze(0),
        full_batch_previous_memory.unsqueeze(0),
    )
    return output_accumulator, final_state
def forward(self,  # pylint: disable=arguments-differ
            inputs: PackedSequence,
            initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
    """
    Parameters
    ----------
    inputs : PackedSequence, required.
        A tensor of shape (batch_size, num_timesteps, input_size)
        to apply the LSTM over.
    initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None)
        A tuple (state, memory) representing the initial hidden state and memory
        of the LSTM. Each tensor has shape (1, batch_size, output_dimension).

    Returns
    -------
    A PackedSequence containing a torch.FloatTensor of shape
    (batch_size, num_timesteps, output_dimension) representing the outputs of the LSTM
    per timestep and a tuple containing the LSTM state, with shape
    (1, batch_size, hidden_size) to match the Pytorch API.
    """
    if not isinstance(inputs, PackedSequence):
        raise ConfigurationError('inputs must be PackedSequence but got %s' %
                                 (type(inputs)))

    sequence_tensor, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
    batch_size = sequence_tensor.size()[0]
    total_timesteps = sequence_tensor.size()[1]

    output_accumulator = sequence_tensor.new_zeros(batch_size, total_timesteps,
                                                   self.hidden_size)
    if initial_state is None:
        full_batch_previous_memory = sequence_tensor.new_zeros(batch_size, self.hidden_size)
        full_batch_previous_state = sequence_tensor.data.new_zeros(batch_size, self.hidden_size)
    else:
        full_batch_previous_state = initial_state[0].squeeze(0)
        full_batch_previous_memory = initial_state[1].squeeze(0)

    current_length_index = batch_size - 1 if self.go_forward else 0
    if self.recurrent_dropout_probability > 0.0:
        dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                        full_batch_previous_memory)
    else:
        dropout_mask = None

    for timestep in range(total_timesteps):
        # The index depends on which end we start.
        index = timestep if self.go_forward else total_timesteps - timestep - 1

        # What we are doing here is finding the index into the batch dimension
        # which we need to use for this timestep, because the sequences have
        # variable length, so once the index is greater than the length of this
        # particular batch sequence, we no longer need to do the computation for
        # this sequence. The key thing to recognise here is that the batch inputs
        # must be _ordered_ by length from longest (first in batch) to shortest
        # (last) so initially, we are going forwards with every sequence and as we
        # pass the index at which the shortest elements of the batch finish,
        # we stop picking them up for the computation.
        if self.go_forward:
            while batch_lengths[current_length_index] <= index:
                current_length_index -= 1
        # If we're going backwards, we are _picking up_ more indices.
        else:
            # First conditional: Are we already at the maximum number of elements in the batch?
            # Second conditional: Does the next shortest sequence beyond the current batch
            # index require computation at this timestep?
            while current_length_index < (len(batch_lengths) - 1) and \
                    batch_lengths[current_length_index + 1] > index:
                current_length_index += 1

        # Actually get the slices of the batch which we need for the computation at this timestep.
        previous_memory = full_batch_previous_memory[0: current_length_index + 1].clone()
        previous_state = full_batch_previous_state[0: current_length_index + 1].clone()
        timestep_input = sequence_tensor[0: current_length_index + 1, index]

        # Do the projections for all the gates all at once.
        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        # Main LSTM equations using relevant chunks of the big linear
        # projections of the hidden state and inputs.
        input_gate = torch.sigmoid(projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] +
                                   projected_state[:, 0 * self.hidden_size:1 * self.hidden_size])
        forget_gate = torch.sigmoid(projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] +
                                    projected_state[:, 1 * self.hidden_size:2 * self.hidden_size])
        memory_init = torch.tanh(projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] +
                                 projected_state[:, 2 * self.hidden_size:3 * self.hidden_size])
        output_gate = torch.sigmoid(projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] +
                                    projected_state[:, 3 * self.hidden_size:4 * self.hidden_size])
        memory = input_gate * memory_init + forget_gate * previous_memory
        timestep_output = output_gate * torch.tanh(memory)

        if self.use_highway:
            highway_gate = torch.sigmoid(
                projected_input[:, 4 * self.hidden_size:5 * self.hidden_size] +
                projected_state[:, 4 * self.hidden_size:5 * self.hidden_size])
            highway_input_projection = projected_input[:, 5 * self.hidden_size:6 * self.hidden_size]
            timestep_output = highway_gate * timestep_output + \
                (1 - highway_gate) * highway_input_projection

        # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
        if dropout_mask is not None and self.training:
            timestep_output = timestep_output * dropout_mask[0: current_length_index + 1]

        # We've been doing computation with less than the full batch, so here we create a new
        # variable for the whole batch at this timestep and insert the result for the
        # relevant elements of the batch into it.
        full_batch_previous_memory = full_batch_previous_memory.data.clone()
        full_batch_previous_state = full_batch_previous_state.data.clone()
        full_batch_previous_memory[0:current_length_index + 1] = memory
        full_batch_previous_state[0:current_length_index + 1] = timestep_output
        output_accumulator[0:current_length_index + 1, index] = timestep_output

    output_accumulator = pack_padded_sequence(output_accumulator, batch_lengths,
                                              batch_first=True)

    # Mimic the pytorch API by returning state in the following shape:
    # (num_layers * num_directions, batch_size, hidden_size). As this
    # LSTM cannot be stacked, the first dimension here is just 1.
    final_state = (full_batch_previous_state.unsqueeze(0),
                   full_batch_previous_memory.unsqueeze(0))

    return output_accumulator, final_state
def forward(self,  # pylint: disable=arguments-differ
            inputs: torch.FloatTensor,
            batch_lengths: List[int],
            initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
    """
    Parameters
    ----------
    inputs : ``torch.FloatTensor``, required.
        A tensor of shape (batch_size, num_timesteps, input_size)
        to apply the LSTM over.
    batch_lengths : ``List[int]``, required.
        A list of length batch_size containing the lengths of the sequences in batch.
    initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None)
        A tuple (state, memory) representing the initial hidden state and memory
        of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the
        ``memory`` has shape (1, batch_size, cell_size).

    Returns
    -------
    output_accumulator : ``torch.FloatTensor``
        The outputs of the LSTM for each timestep. A tensor of shape
        (batch_size, max_timesteps, hidden_size) where for a given batch
        element, all outputs past the sequence length for that batch are
        zero tensors.
    final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
        A tuple (state, memory) representing the final hidden state and memory
        of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and
        the ``memory`` has shape (1, batch_size, cell_size).
    """
    batch_size = inputs.size()[0]
    total_timesteps = inputs.size()[1]

    output_accumulator = inputs.new_zeros(batch_size, total_timesteps, self.hidden_size)

    if initial_state is None:
        full_batch_previous_memory = inputs.new_zeros(batch_size, self.cell_size)
        full_batch_previous_state = inputs.new_zeros(batch_size, self.hidden_size)
    else:
        full_batch_previous_state = initial_state[0].squeeze(0)
        full_batch_previous_memory = initial_state[1].squeeze(0)

    current_length_index = batch_size - 1 if self.go_forward else 0
    if self.recurrent_dropout_probability > 0.0 and self.training:
        dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                        full_batch_previous_state)
    else:
        dropout_mask = None

    for timestep in range(total_timesteps):
        # The index depends on which end we start.
        index = timestep if self.go_forward else total_timesteps - timestep - 1

        # What we are doing here is finding the index into the batch dimension
        # which we need to use for this timestep, because the sequences have
        # variable length, so once the index is greater than the length of this
        # particular batch sequence, we no longer need to do the computation for
        # this sequence. The key thing to recognise here is that the batch inputs
        # must be _ordered_ by length from longest (first in batch) to shortest
        # (last) so initially, we are going forwards with every sequence and as we
        # pass the index at which the shortest elements of the batch finish,
        # we stop picking them up for the computation.
        if self.go_forward:
            while batch_lengths[current_length_index] <= index:
                current_length_index -= 1
        # If we're going backwards, we are _picking up_ more indices.
        else:
            # First conditional: Are we already at the maximum number of elements in the batch?
            # Second conditional: Does the next shortest sequence beyond the current batch
            # index require computation at this timestep?
            while current_length_index < (len(batch_lengths) - 1) and \
                    batch_lengths[current_length_index + 1] > index:
                current_length_index += 1

        # Actually get the slices of the batch which we
        # need for the computation at this timestep.
        # shape (batch_size, cell_size)
        previous_memory = full_batch_previous_memory[0: current_length_index + 1].clone()
        # Shape (batch_size, hidden_size)
        previous_state = full_batch_previous_state[0: current_length_index + 1].clone()
        # Shape (batch_size, input_size)
        timestep_input = inputs[0: current_length_index + 1, index]

        # Do the projections for all the gates all at once.
        # Both have shape (batch_size, 4 * cell_size)
        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        # Main LSTM equations using relevant chunks of the big linear
        # projections of the hidden state and inputs.
        input_gate = torch.sigmoid(projected_input[:, (0 * self.cell_size):(1 * self.cell_size)] +
                                   projected_state[:, (0 * self.cell_size):(1 * self.cell_size)])
        forget_gate = torch.sigmoid(projected_input[:, (1 * self.cell_size):(2 * self.cell_size)] +
                                    projected_state[:, (1 * self.cell_size):(2 * self.cell_size)])
        memory_init = torch.tanh(projected_input[:, (2 * self.cell_size):(3 * self.cell_size)] +
                                 projected_state[:, (2 * self.cell_size):(3 * self.cell_size)])
        output_gate = torch.sigmoid(projected_input[:, (3 * self.cell_size):(4 * self.cell_size)] +
                                    projected_state[:, (3 * self.cell_size):(4 * self.cell_size)])
        memory = input_gate * memory_init + forget_gate * previous_memory

        # Here is the non-standard part of this LSTM cell; first, we clip the
        # memory cell, then we project the output of the timestep to a smaller size
        # and again clip it.
        if self.memory_cell_clip_value:
            # pylint: disable=invalid-unary-operand-type
            memory = torch.clamp(memory, -self.memory_cell_clip_value, self.memory_cell_clip_value)

        # shape (current_length_index, cell_size)
        pre_projection_timestep_output = output_gate * torch.tanh(memory)

        # shape (current_length_index, hidden_size)
        timestep_output = self.state_projection(pre_projection_timestep_output)
        if self.state_projection_clip_value:
            # pylint: disable=invalid-unary-operand-type
            timestep_output = torch.clamp(timestep_output,
                                          -self.state_projection_clip_value,
                                          self.state_projection_clip_value)

        # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
        if dropout_mask is not None:
            timestep_output = timestep_output * dropout_mask[0: current_length_index + 1]

        # We've been doing computation with less than the full batch, so here we create a new
        # variable for the whole batch at this timestep and insert the result for the
        # relevant elements of the batch into it.
        full_batch_previous_memory = full_batch_previous_memory.clone()
        full_batch_previous_state = full_batch_previous_state.clone()
        full_batch_previous_memory[0:current_length_index + 1] = memory
        full_batch_previous_state[0:current_length_index + 1] = timestep_output
        output_accumulator[0:current_length_index + 1, index] = timestep_output

    # Mimic the pytorch API by returning state in the following shape:
    # (num_layers * num_directions, batch_size, ...). As this
    # LSTM cell cannot be stacked, the first dimension here is just 1.
    final_state = (full_batch_previous_state.unsqueeze(0),
                   full_batch_previous_memory.unsqueeze(0))

    return output_accumulator, final_state
def _apply_push(self) -> None:
    index_list = []
    inputs = []
    initial_state = []
    layer_dropout_mask = []
    recurrent_dropout_mask = []
    for i, (stack, buffer) in enumerate(zip(self.stacks, self.push_buffer)):
        if buffer is not None:
            index_list.append(i)
            inputs.append(buffer['stack_rnn_input'].unsqueeze(0))
            if len(stack) > 0:
                initial_state.append(
                    (stack[-1]['stack_rnn_state'].unsqueeze(1),
                     stack[-1]['stack_rnn_memory'].unsqueeze(1)))
            else:
                initial_state.append((buffer['stack_rnn_input'].new_zeros(
                    self.num_layers, 1, self.hidden_size), ) * 2)
            if self.same_dropout_mask_per_instance:
                if self.layer_dropout_mask is not None:
                    layer_dropout_mask.append(
                        self.layer_dropout_mask[:, i].unsqueeze(1))
                if self.recurrent_dropout_mask is not None:
                    recurrent_dropout_mask.append(
                        self.recurrent_dropout_mask[:, i].unsqueeze(1))
            else:
                if 0.0 < self.layer_dropout_probability < 1.0:
                    layer_dropout_mask.append(
                        get_dropout_mask(
                            self.layer_dropout_probability,
                            torch.ones(self.num_layers, 1, self.hidden_size,
                                       device=self.layer_0.input_linearity.weight.device)))
                if 0.0 < self.recurrent_dropout_probability < 1.0:
                    recurrent_dropout_mask.append(
                        get_dropout_mask(
                            self.recurrent_dropout_probability,
                            torch.ones(self.num_layers, 1, self.hidden_size,
                                       device=self.layer_0.input_linearity.weight.device)))
    if len(layer_dropout_mask) == 0:
        layer_dropout_mask = None
    if len(recurrent_dropout_mask) == 0:
        recurrent_dropout_mask = None
    if len(index_list) > 0:
        inputs = torch.cat(inputs, 0)
        initial_state = list(torch.cat(t, 1) for t in zip(*initial_state))
        if layer_dropout_mask is not None:
            layer_dropout_mask = torch.cat(layer_dropout_mask, 1)
        if recurrent_dropout_mask is not None:
            recurrent_dropout_mask = torch.cat(recurrent_dropout_mask, 1)
        output_state, output_memory = self._forward(inputs, initial_state,
                                                    layer_dropout_mask,
                                                    recurrent_dropout_mask)
        for i, stack_index in enumerate(index_list):
            output = {
                'stack_rnn_state': output_state[:, i, :],
                'stack_rnn_memory': output_memory[:, i, :],
                'stack_rnn_output': output_state[-1, i, :]
            }
            output.update(self.push_buffer[stack_index])
            self.stacks[stack_index].append(output)
            self.push_buffer[stack_index] = None
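# The line `initial_state = list(torch.cat(t, 1) for t in zip(*initial_state))`
# above merges a list of per-stack (state, memory) pairs into one batched pair
# before the single `_forward` call. A self-contained illustration with made-up
# sizes (num_layers=2, hidden_size=3, two pending pushes):
import torch

per_stack_states = [
    (torch.zeros(2, 1, 3), torch.zeros(2, 1, 3)),  # stack A: (state, memory)
    (torch.ones(2, 1, 3), torch.ones(2, 1, 3)),    # stack B: (state, memory)
]
state, memory = (torch.cat(t, 1) for t in zip(*per_stack_states))
# Both are now (num_layers, num_pushed_stacks, hidden_size), so _forward can run
# every pending push as one batch; column i is later written back to stack i.
assert state.shape == (2, 2, 3) and memory.shape == (2, 2, 3)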